diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 2e8b0970f..d86cc2ba6 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,11 +1,6 @@ [bumpversion] -current_version = 0.1.113 +current_version = 2.19.2 commit = True tag = True [bumpversion:file:pychunkedgraph/__init__.py] - -[bumpversion:file:pychunkedgraph/app/cg_app_blueprint.py] - -[bumpversion:file:pychunkedgraph/app/meshing_app_blueprint.py] - diff --git a/.coveragerc b/.coveragerc index 661a357cc..a38e1c392 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,7 +1,6 @@ # .coveragerc to control coverage.py [run] branch = True -concurrency = "multiprocessing" source = pychunkedgraph omit = *test* diff --git a/.dockerignore b/.dockerignore index dd7b29495..66349c8a4 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,13 +1,124 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class -build -dist -__pycache__ -.pytest_cache -.tox -*.egg-info -*.egg/ -*.pyc -*.swp +# C extensions +*.so +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +.idea/* + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ .coverage -.vscode +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# Visual Code +.vscode/ + +# terraform +.terraform/ +*.lock.hcl +*.tfstate +*.tfstate.* + + +# local dev stuff +.devcontainer/ +*.ipynb +*.rdb +/protobuf* + +# Git +.git/ \ No newline at end of file diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 000000000..899f0431f --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,22 @@ +name: PyChunkedGraph + +on: + push: + branches: + - "main" + pull_request: + branches: + - "main" + +jobs: + unit-tests: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v2 + + - name: Build image and run tests + run: | + docker build --tag seunglab/pychunkedgraph:$GITHUB_SHA . 
+ docker run --rm seunglab/pychunkedgraph:$GITHUB_SHA /bin/sh -c "pytest --cov-config .coveragerc --cov=pychunkedgraph ./pychunkedgraph/tests && codecov" + diff --git a/.gitignore b/.gitignore index ef8107b32..498253791 100644 --- a/.gitignore +++ b/.gitignore @@ -23,7 +23,6 @@ wheels/ *.egg-info/ .installed.cfg *.egg -MANIFEST .idea/* # PyInstaller @@ -108,7 +107,16 @@ venv.bak/ # Visual Code .vscode/ +# terraform +.terraform/ +*.lock.hcl +*.tfstate +*.tfstate.* + # local dev stuff -output.txt -src/ \ No newline at end of file +.devcontainer/ +*.ipynb +*.rdb +/protobuf* +.DS_Store \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 392ef925d..a5e33242d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,6 +2,10 @@ sudo: true services: docker +env: + global: + - CLOUDSDK_CORE_DISABLE_PROMPTS=1 + stages: - test - name: merge-deploy @@ -21,6 +25,11 @@ jobs: - ci_env=`bash <(curl -s https://codecov.io/env)` script: + - openssl aes-256-cbc -K $encrypted_506e835c2891_key -iv $encrypted_506e835c2891_iv -in key.json.enc -out key.json -d + - curl https://sdk.cloud.google.com | bash > /dev/null + - source "$HOME/google-cloud-sdk/path.bash.inc" + - gcloud auth activate-service-account --key-file=key.json + - gcloud auth configure-docker - docker build --tag seunglab/pychunkedgraph:$TRAVIS_BRANCH . || travis_terminate 1 - docker run $ci_env --rm seunglab/pychunkedgraph:$TRAVIS_BRANCH /bin/sh -c "tox -v -- --cov-config .coveragerc --cov=pychunkedgraph && codecov" diff --git a/Dockerfile b/Dockerfile index 48b20aac0..2b7eeb151 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,62 +1,11 @@ -FROM tiangolo/uwsgi-nginx-flask:python3.6 - +FROM caveconnectome/pychunkedgraph:base_042124 +ENV VIRTUAL_ENV=/app/venv +ENV PATH="$VIRTUAL_ENV/bin:$PATH" +COPY override/gcloud /app/venv/bin/gcloud COPY override/timeout.conf /etc/nginx/conf.d/timeout.conf COPY override/supervisord.conf /etc/supervisor/conf.d/supervisord.conf -COPY requirements.txt /app -RUN mkdir -p /home/nginx/.cloudvolume/secrets \ - && chown -R nginx /home/nginx \ - && usermod -d /home/nginx -s /bin/bash nginx \ - && apt-get update \ - && apt-get install -y \ - # Boost and g++ for compiling DracoPy and graph_tool - build-essential \ - libboost-dev \ - # Required for adding graph-tools and cloud-sdk to the apt source list - lsb-release \ - curl \ - apt-transport-https \ - # GOOGLE-CLOUD-SDK - && pip install --no-cache-dir --upgrade crcmod \ - && echo "deb https://packages.cloud.google.com/apt cloud-sdk-$(lsb_release -c -s) main" > /etc/apt/sources.list.d/google-cloud-sdk.list \ - && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - \ - && apt-get update \ - && apt-get install -y google-cloud-sdk google-cloud-sdk-bigtable-emulator \ - # PYTHON-GRAPH-TOOL - # WARNING: This is ugly, graph-tools will use Debian's Python version and install as dist-util, - # but all our packages use the site-util Python version - we just create a sym_link, - # because it _seems_ to work and saves 80 minutes compilation time ... 
- && echo "deb http://downloads.skewed.de/apt/$(lsb_release -s -c) $(lsb_release -s -c) main" > /etc/apt/sources.list.d/graph-tool.list \ - && echo "deb-src http://downloads.skewed.de/apt/$(lsb_release -s -c) $(lsb_release -s -c) main" >> /etc/apt/sources.list.d/graph-tool.list \ - && apt-key adv --no-tty --keyserver hkp://keyserver.ubuntu.com --recv-key 612DEFB798507F25 \ - && apt-get update \ - && apt-get install -y python3-graph-tool \ - && ln -s /usr/lib/python3/dist-packages/graph_tool /usr/local/lib/python3.6/site-packages/graph_tool \ - && pip install --no-cache-dir --upgrade scipy \ - # PYCHUNKEDGRAPH - # Need pip 18.1 for process-dependency-links flag support - && pip install --no-cache-dir pip==18.1 \ - # Need numpy to prevent install issue with cloud-volume / fpzip - && pip install --no-cache-dir --upgrade numpy \ - && pip install --no-cache-dir --upgrade --process-dependency-links -r requirements.txt \ - # Tests - && pip install tox codecov \ - # CLEANUP - # libboost-dev and build-essentials will be required by tox to build python dependencies - && apt-get remove --purge -y lsb-release curl \ - && apt-get autoremove --purge -y \ - && rm -rf /var/lib/apt/lists/* \ - && find /usr/local/lib/python3* -depth \ - \( \ - \( -type d -a \( -name __pycache__ \) \) \ - -o \ - \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ - \) -exec rm -rf '{}' + \ - && find /usr/lib/python3* -depth \ - \( \ - \( -type d -a \( -name __pycache__ \) \) \ - -o \ - \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ - \) -exec rm -rf '{}' + -COPY . /app \ No newline at end of file +COPY requirements.txt . +RUN pip install --upgrade -r requirements.txt +COPY . /app diff --git a/README.md b/README.md index 10cf16b18..ef888b3c6 100644 --- a/README.md +++ b/README.md @@ -4,27 +4,54 @@ [![Build Status](https://travis-ci.org/seung-lab/PyChunkedGraph.svg?branch=master)](https://travis-ci.org/seung-lab/PyChunkedGraph) [![codecov](https://codecov.io/gh/seung-lab/PyChunkedGraph/branch/master/graph/badge.svg)](https://codecov.io/gh/seung-lab/PyChunkedGraph) -The PyChunkedGraph is a proofreading and segmentation data management backend with (not limited to) the following features: -- Concurrent proofreading by multiple users without restrictions on the workflow -- Continuous versioning of proofreading edits -- Making changes visible to all users immediately -- Local mincut computations +The PyChunkedGraph is a proofreading and segmentation data management backend powering FlyWire and other proofreading platforms. It builds on an initial agglomeration of supervoxels and facilitates fast and parallel editing of connected components in the agglomeration graph by many users. -## Scaling to large datasets +## PyChunkedGraph versions +The main branch represents the second version (v2) of the PyChunkedGraph implementation. The first version (v1) is still maintained and can be found under `pcgv1`. The v2 implementation resolved data storage concerns and removed some scaling bottlenecks in the implementation. Any new dataset should use the v2. -## Deployment +## Using the PyChunkedGraph -While the PyChunkedGraph can be deployed as a stand-alone entity, we deploy it within our [annotation infrastructure](https://github.com/seung-lab/AnnotationPipelineOverview) which uses a CI Kubernetes deployment. We will make instructions and scripts for a mostly automated deployment to Google Cloud available soon. +The ChunkedGraph is built on Google Cloud BigTable. 
A BigTable instance is required to use this ChunkedGraph implementation.
-## Building your own PyChunkedGraph
+### Environment Variables
+There are three environment variables that need to be set
+to connect to a chunkedgraph:
+- `GOOGLE_APPLICATION_CREDENTIALS`: Location of the google-secret.json file.
+- `BIGTABLE_PROJECT`: Name of the Google Cloud project.
+- `BIGTABLE_INSTANCE`: The BigTable instance ID. (Default is 'pychunkedgraph')
+### Ingest
+
+`/ingest` provides examples for ingest scripts. The ingestion pipeline is designed to use the output of the Seung Lab's agglomeration pipeline but can be adjusted to use alternative data sources.
+
+### Deployment / Import
+
+The PyChunkedGraph can be deployed locally (`run_dev.py`), imported in a Python script (`from pychunkedgraph.backend import chunkedgraph`), or deployed on a Kubernetes server. Deployment code for a Kubernetes server on Google Cloud is not included in this repository. Please feel free to reach out if you are interested in that.
 ## System Design
 As a backend the PyChunkedGraph can be combined with any frontend that adheres to its API. We use an adapted version of [neuroglancer](https://github.com/seung-lab/neuroglancer/tree/nkem-multicut) which is publicly available.
-[system_design]: https://github.com/seung-lab/PyChunkedGraph/blob/master/ProofreadingDiagram.png "System Design"
-![alt text][system_design]
+## Publication
+
+When using or referencing the PyChunkedGraph, please use the citation below. The FlyWire paper described and published the PyChunkedGraph v1.
+
+[FlyWire: Online community for whole-brain connectomics](https://www.nature.com/articles/s41592-021-01330-0)
+```
+@article{FlyWire2021,
+ doi = {10.1038/s41592-021-01330-0},
+ url = {https://doi.org/10.1038/s41592-021-01330-0},
+ year = {2021},
+ month = dec,
+ publisher = {Springer Science and Business Media {LLC}},
+ volume = {19},
+ number = {1},
+ pages = {119--128},
+ author = {Sven Dorkenwald and Claire E. McKellar and Thomas Macrina and Nico Kemnitz and Kisuk Lee and Ran Lu and Jingpeng Wu and Sergiy Popovych and Eric Mitchell and Barak Nehoran and Zhen Jia and J. Alexander Bae and Shang Mu and Dodam Ih and Manuel Castro and Oluwaseun Ogedengbe and Akhilesh Halageri and Kai Kuehner and Amy R. Sterling and Zoe Ashwood and Jonathan Zung and Derrick Brittain and Forrest Collman and Casey Schneider-Mizell and Chris Jordan and William Silversmith and Christa Baker and David Deutsch and Lucas Encarnacion-Rivera and Sandeep Kumar and Austin Burke and Doug Bland and Jay Gager and James Hebditch and Selden Koolman and Merlin Moore and Sarah Morejohn and Ben Silverman and Kyle Willie and Ryan Willie and Szi-chieh Yu and Mala Murthy and H.
Sebastian Seung}, + title = {{FlyWire}: online community for whole-brain connectomics}, + journal = {Nature Methods} +} +``` diff --git a/base.Dockerfile b/base.Dockerfile new file mode 100644 index 000000000..b5123e137 --- /dev/null +++ b/base.Dockerfile @@ -0,0 +1,70 @@ +ARG PYTHON_VERSION=3.11 +ARG BASE_IMAGE=tiangolo/uwsgi-nginx-flask:python${PYTHON_VERSION} + + +###################################################### +# Build Image - PCG dependencies +###################################################### +FROM ${BASE_IMAGE} AS pcg-build +ENV PATH="/root/miniconda3/bin:${PATH}" +ENV CONDA_ENV="pychunkedgraph" + +# Setup Miniconda +RUN apt-get update && apt-get install build-essential wget -y +RUN wget \ + https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ + && mkdir /root/.conda \ + && bash Miniconda3-latest-Linux-x86_64.sh -b \ + && rm -f Miniconda3-latest-Linux-x86_64.sh \ + && conda update conda + +# Install PCG dependencies - especially graph-tool +# Note: uwsgi has trouble with pip and python3.11, so adding this with conda, too +COPY requirements.txt . +COPY requirements.yml . +COPY requirements-dev.txt . +RUN conda env create -n ${CONDA_ENV} -f requirements.yml + +# Shrink conda environment into portable non-conda env +RUN conda install conda-pack -c conda-forge + +RUN conda-pack -n ${CONDA_ENV} --ignore-missing-files -o /tmp/env.tar \ + && mkdir -p /app/venv \ + && cd /app/venv \ + && tar xf /tmp/env.tar \ + && rm /tmp/env.tar +RUN /app/venv/bin/conda-unpack + + +###################################################### +# Build Image - Bigtable Emulator (without Google SDK) +###################################################### +FROM golang:bullseye as bigtable-emulator-build +RUN mkdir -p /usr/src +WORKDIR /usr/src +ENV GOOGLE_CLOUD_GO_VERSION bigtable/v1.19.0 +RUN apt-get update && apt-get install git -y +RUN git clone --depth=1 --branch="$GOOGLE_CLOUD_GO_VERSION" https://github.com/googleapis/google-cloud-go.git . \ + && cd bigtable \ + && go install -v ./cmd/emulator + + +###################################################### +# Production Image +###################################################### +FROM ${BASE_IMAGE} +ENV VIRTUAL_ENV=/app/venv +ENV PATH="$VIRTUAL_ENV/bin:$PATH" + +COPY --from=pcg-build /app/venv /app/venv +COPY --from=bigtable-emulator-build /go/bin/emulator /app/venv/bin/cbtemulator +COPY override/gcloud /app/venv/bin/gcloud +COPY override/timeout.conf /etc/nginx/conf.d/timeout.conf +COPY override/supervisord.conf /etc/supervisor/conf.d/supervisord.conf +# Hack to get zstandard from PyPI - remove if conda-forge linked lib issue is resolved +RUN pip install --no-cache-dir --no-deps --force-reinstall zstandard==0.21.0 +COPY . /app + +RUN mkdir -p /home/nginx/.cloudvolume/secrets \ + && chown -R nginx /home/nginx \ + && usermod -d /home/nginx -s /bin/bash nginx diff --git a/build_pypi.sh b/build_pypi.sh old mode 100755 new mode 100644 diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 66c045fbc..21f4cc58d 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -1,18 +1,28 @@ steps: -# - name: 'gcr.io/cloud-builders/docker' -# args: [ 'build', '-t', 'gcr.io/$PROJECT_ID/pychunkedgraph', '.' ] -- name: 'gcr.io/cloud-builders/docker' - entrypoint: 'bash' - args: - - '-c' - - | - docker build -t gcr.io/$PROJECT_ID/pychunkedgraph:$TAG_NAME . - timeout: 600s -#- name: 'gcr.io/cloud-builders/docker' -# entrypoint: 'bash' -# args: -# - '-c' -# - | -# [[ "$BRANCH_NAME" == "master" ]] && docker build -t gcr.io/$PROJECT_ID/pychunkedgraph . 
-images: -- 'gcr.io/$PROJECT_ID/pychunkedgraph:$TAG_NAME' + # Login to Docker Hub + - name: "gcr.io/cloud-builders/docker" + entrypoint: "bash" + args: ["-c", "docker login --username=$$USERNAME --password=$$PASSWORD"] + secretEnv: ["USERNAME", "PASSWORD"] + + - name: "gcr.io/cloud-builders/docker" + entrypoint: "bash" + args: + - "-c" + - | + docker build -t $$USERNAME/pychunkedgraph:$TAG_NAME . + timeout: 600s + secretEnv: ["USERNAME"] + + # Push the final image to Dockerhub + - name: "gcr.io/cloud-builders/docker" + entrypoint: "bash" + args: ["-c", "docker push $$USERNAME/pychunkedgraph:$TAG_NAME"] + secretEnv: ["USERNAME"] + +availableSecrets: + secretManager: + - versionName: projects/$PROJECT_ID/secrets/docker-password/versions/1 + env: "PASSWORD" + - versionName: projects/$PROJECT_ID/secrets/docker-username/versions/1 + env: "USERNAME" diff --git a/compile_reqs.sh b/compile_reqs.sh new file mode 100755 index 000000000..2d74c225d --- /dev/null +++ b/compile_reqs.sh @@ -0,0 +1 @@ +docker run -v ${PWD}:/app caveconnectome/pychunkedgraph:v2.4.0 /bin/bash -c "pip install pip-tools && pip-compile requirements.in --resolver=backtracking -v --output-file requirements.txt" \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 96ccc0919..000000000 --- a/docker-compose.yml +++ /dev/null @@ -1,56 +0,0 @@ -version: '3.7' - -services: - - pcg_1: - build: - context: . - image: pcg - container_name: pcg_1 - environment: - - APP_SETTINGS=pychunkedgraph.app.config.DeploymentWithRedisConfig - - FLASK_APP=run_dev_cli.py - - REDIS_SERVICE_HOST=redis - - REDIS_SERVICE_PORT=6379 - - REDIS_PASSWORD=dev - volumes: - - .:/app - - ~/secrets:/root/.cloudvolume/secrets - ports: - - '80:80' - - '4000:4000' - depends_on: - - redis - - pcg_2: - image: pcg - container_name: pcg_2 - environment: - - APP_SETTINGS=pychunkedgraph.app.config.DeploymentWithRedisConfig - - FLASK_APP=run_dev_cli.py - - REDIS_SERVICE_HOST=redis - - REDIS_SERVICE_PORT=6379 - - REDIS_PASSWORD=dev - volumes: - - .:/app - - ~/secrets:/root/.cloudvolume/secrets - ports: - - '81:80' - - '4001:4000' - depends_on: - - pcg_1 - - redis - - redis: - image: redis:5.0.4-alpine - container_name: redis - ports: - - '6379:6379' - command: ["redis-server", "--requirepass", "dev"] - - rq-dashboard: - image: python:3.6.8-alpine - container_name: rq-dashboard - ports: - - '9181:9181' - command: [sh, -c, "pip install rq-dashboard && rq-dashboard -u redis://:dev@redis:6379"] \ No newline at end of file diff --git a/pychunkedgraph/backend/Readme.md b/docs/Readme.md similarity index 92% rename from pychunkedgraph/backend/Readme.md rename to docs/Readme.md index 20d428bef..45799326e 100644 --- a/pychunkedgraph/backend/Readme.md +++ b/docs/Readme.md @@ -138,12 +138,12 @@ id_history = cg.read_agglomeration_id_history(root_id) ### Read the local ChunkedGraph -To read the edge list and edge affinities of all atmic super voxels belonging to a root node do +To read the edge list and edge affinities of all atomic super voxels belonging to a root node do ``` -cg.get_subgraph(root_id, bounding_box, bb_is_coordinate=True) +cg.get_subgraph(root_id, bounding_box, bbox_is_coordinate=True) ``` -The user can define a `bounding_box=[[x_l, y_l, z_l], [x_h, y_h, z_h]]` as either coordinates or chunk id range (use `bb_is_coordinate`). The `bounding_box` feature is currently not working, the parameter is ignored. The current datset is small enough to all reads of atomic supervoxels. 
Hence, this should not hinder any development.
+The user can define a `bounding_box=[[x_l, y_l, z_l], [x_h, y_h, z_h]]` as either coordinates or a chunk id range (use `bbox_is_coordinate`). The `bounding_box` feature is currently not working; the parameter is ignored. The current dataset is small enough to allow reading all atomic supervoxels. Hence, this should not hinder any development.
 ##### Minimal example
 ```
diff --git a/docs/edges.md b/docs/edges.md
new file mode 100644
index 000000000..ca483423e
--- /dev/null
+++ b/docs/edges.md
@@ -0,0 +1,75 @@
+## Serialization
+
+PyChunkedGraph uses protobuf for serialization and zstandard for compression.
+
+Edges and connected components per chunk are stored using the protobuf definitions in [`pychunkedgraph.io.protobuf`](https://github.com/seung-lab/PyChunkedGraph/pychunkedgraph/io/protobuf/chunkEdges.proto).
+This format is the result of performance tests.
+It provided the best tradeoff between deserialization speed and storage size.
+
+To read and write edges in this format, the functions `get_chunk_edges` and `put_chunk_edges`
+in the module `pychunkedgraph.io.edges` may be used.
+
+[CloudVolume](https://github.com/seung-lab/cloud-volume) is used for uploading and downloading this data.
+
+### Edges
+
+Edges in the ChunkedGraph refer to edges between supervoxels (groups of voxels).
+These supervoxels are the atomic nodes in the graph; they cannot be split.
+
+There are three types of edges in a chunk:
+1. `in`: an edge between supervoxels within the chunk boundary
+2. `between`: an edge between supervoxels in adjacent chunks
+3. `cross`: a faux edge between parts of the same supervoxel that has been split across a chunk boundary
+
+### Components
+
+A component is simply a mapping of a supervoxel to its connected component.
+Components within a single chunk are stored as a numpy array.
+```
+[
+ component1_size,
+ supervoxel_a,
+ supervoxel_b,
+ supervoxel_c,
+ component2_size,
+ supervoxel_x,
+ supervoxel_y,
+ ...
+]
+```
+
+### Example usage
+
+```
+import numpy as np
+
+from pychunkedgraph.io.edges import get_chunk_edges
+from pychunkedgraph.io.edges import put_chunk_edges
+from pychunkedgraph.graph.edges import Edges
+from pychunkedgraph.graph.edges import EDGE_TYPES
+
+# supervoxel id pairs as numpy arrays so they can be column-indexed below
+in_chunk = np.array([[1, 2], [2, 3], [0, 2], [2, 4]])
+between_chunk = np.array([[1, 5]])
+cross_chunk = np.array([[3, 6]])
+
+in_chunk_edges = Edges(in_chunk[:, 0], in_chunk[:, 1])
+between_chunk_edges = Edges(between_chunk[:, 0], between_chunk[:, 1])
+cross_chunk_edges = Edges(cross_chunk[:, 0], cross_chunk[:, 1])
+
+edges_path = ""
+chunk_coordinates = np.array([0, 0, 0])
+
+edges_d = {
+    EDGE_TYPES.in_chunk: in_chunk_edges,
+    EDGE_TYPES.between_chunk: between_chunk_edges,
+    EDGE_TYPES.cross_chunk: cross_chunk_edges
+}
+
+put_chunk_edges(edges_path, chunk_coordinates, edges_d, compression_level=22)
+# file will be located at /edges_0_0_0.proto.zst
+
+# reading the file will simply return the previous dictionary
+edges_d = get_chunk_edges(edges_path, [chunk_coordinates])
+
+# note the difference in the chunk_coordinates parameter:
+# put_chunk_edges takes the coordinates of a single chunk
+# get_chunk_edges takes a list of chunk coordinates
+```
diff --git a/ProofreadingDiagram.png b/docs/images/ProofreadingDiagram.png
similarity index 100%
rename from ProofreadingDiagram.png
rename to docs/images/ProofreadingDiagram.png
diff --git a/docs/images/edges/1.png b/docs/images/edges/1.png
new file mode 100644
index 000000000..c66d811e7
Binary files /dev/null and b/docs/images/edges/1.png differ
diff --git a/docs/images/edges/2.png b/docs/images/edges/2.png
new file mode 100644
index 000000000..2384b452a
Binary files /dev/null and b/docs/images/edges/2.png differ
diff --git a/docs/images/edges/3.png b/docs/images/edges/3.png
new file mode 100644
index 000000000..4dbd040a6
Binary files /dev/null and b/docs/images/edges/3.png differ
diff --git a/docs/images/edges/4.png b/docs/images/edges/4.png
new file mode 100644
index 000000000..62c84f837
Binary files /dev/null and b/docs/images/edges/4.png differ
diff --git a/docs/meshing.md b/docs/meshing.md
new file mode 100644
index 000000000..c0289a795
--- /dev/null
+++ b/docs/meshing.md
@@ -0,0 +1,65 @@
+## Meshing Overview
+
+This file is intended to explain and record how our meshing procedure works and why we designed it this way.
+Meshing is done by starting at layer 2 in the chunk graph and meshing each chunk within it. For each higher layer in the
+graph, adjacent chunks in the layer below it are stitched together.
+
+### How meshes are created
+
+Meshes are created by taking 3D labeled voxel segmentation data and passing it to ZMesh, which runs marching cubes on the
+data and produces meshes for the different objects in it. ZMesh produces a list of floats and one of unsigned integers.
+The floats represent the coordinates of the vertices of the mesh; the unsigned ints represent the triangular faces of the mesh --
+indices into the list of floats. After running marching cubes, ZMesh runs a quadratic simplification algorithm on the mesh to simplify
+the output. ZMesh is a Python library written by Will Silversmith that wraps a C++ meshing library originally written by Alexander Zlateski.
+
+When we mesh each chunk, we download the segmentation data from CloudVolume for the chunk, as well as a 1-voxel overlap in the positive x, y, and z
+directions (so that meshes across adjacent chunks that represent the same object will "stitch" together). This 3D subvolume is passed into ZMesh.
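+
+Below is a minimal, illustrative sketch of this step, assuming a precomputed segmentation and hypothetical chunk bounds; it is not the actual MeshTask in `meshgen.py`, which additionally handles the Draco encoding and per-layer bookkeeping described in the following sections. The zmesh calls follow that library's documented `Mesher` example (newer zmesh releases expose a slightly different retrieval API).
+
+```
+import numpy as np
+from cloudvolume import CloudVolume
+from zmesh import Mesher
+
+SEG_PATH = "gs://example-bucket/segmentation"  # hypothetical precomputed layer
+MIP = 2                                        # hypothetical mip level
+chunk_start = np.array([0, 0, 0])              # chunk offset in voxels (hypothetical)
+chunk_size = np.array([256, 256, 256])         # chunk shape in voxels (hypothetical)
+
+cv = CloudVolume(SEG_PATH, mip=MIP, fill_missing=True)
+
+# include a 1-voxel overlap in the positive x, y and z directions so that
+# meshes of the same object in adjacent chunks share boundary vertices
+end = chunk_start + chunk_size + 1
+seg = np.squeeze(cv[chunk_start[0]:end[0], chunk_start[1]:end[1], chunk_start[2]:end[2]])
+
+mesher = Mesher(tuple(cv.resolution))  # voxel anisotropy in nm
+mesher.mesh(seg)                       # marching cubes over all labels in the cutout
+meshes = {obj_id: mesher.get_mesh(obj_id) for obj_id in mesher.ids()}
+```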
+
+![Two Adjacent Chunks](meshing_diagrams/TwoAdjacentChunks.png)
+
+Above: Two adjacent chunks (ABFE and EFGH) that will be passed into the MeshTask. EFCD is the 1 voxel overlap region when meshing ABFE, so the actual
+region that will be meshed is ABCD.
+
+## The issue that using Draco introduces
+
+Draco quantizes the coordinates of the vertices in its mesh: it creates a grid with equally spaced points over the mesh
+and changes the coordinates of each vertex to the gridpoint that vertex is closest to. We can choose the size of the grid and its origin,
+but it has to be a cube, and we have limited control over the number of points in the grid. This is an issue because our chunks are
+not necessarily cubes, and because we need meshes that cross chunk boundaries to have identical vertices on each side of the boundary in order to
+stitch them together at the layer above. Because our chunks are not cubes, if we specify the bounding box of the grid to be the
+smallest cube that contains the chunk we are meshing, some of the chunk's boundaries will not lie on the grid. The vertices on those
+boundaries will be snapped to the grid, so there is no way to stitch the affected meshes to their counterparts on the other
+side of the chunk boundary.
+
+![Unstable Chunk Boundary](meshing_diagrams/UnstableChunkBoundary.png)
+
+Above is a diagram to visualize the problem. ABFE and EFGH are our PCG chunks, ABIJ and EFKL our respective Draco bounding cubes. The problem with the
+above setup is that if EF does not lie on the Draco grid when meshing ABIJ, it will move, and it will end up in a different location than when meshing
+EFKL.
+
+## Our resolution to this issue
+
+In order to resolve the above issue we need a way of choosing the Draco bounding boxes such that any chunk boundary will be snapped
+to the same location when meshing from either chunk on the two sides of the boundary. We cannot avoid the boundaries moving, but we can ensure
+that given any pair of adjacent chunks, their shared boundary moves to the same location after meshing.
+
+To do this we create a global grid over the entire worldspace at each layer, where the lines in this grid will be exactly where the quantization
+lines in the Draco bounding box grids will be. We set the size of each Draco cube to be a specific multiple of the length between each point in the
+grid, and we make sure the origin of each Draco cube is at some point in the grid. We do this because if we mesh two adjacent chunks,
+where the chunks' Draco cubes overlap and sit in this global grid, then the boundary these two chunks share will be moved to the same plane in the
+grid when meshed from either chunk.
+
+![Stable Chunk Boundary](meshing_diagrams/StableChunkBoundary.png)
+
+The above diagram seems similar to the previous setup, but the key difference is that the length of a side of the new cube is 7 instead of 6.
+Draco forces us to have 2^n grid points, but that includes the beginning and end of the grid, meaning that if we want our grid points to be at
+integer values our Draco cube side length needs to be a multiple of 2^n - 1, in this case n = 3, so 7. This n is the number of bits Draco uses
+internally to encode a coordinate of a vertex.
+
+## Details of how the Draco parameters are chosen
+
+The details of how the Draco parameters are chosen are in the function `get_draco_encoding_settings_for_chunk_exact_boundary` in `meshgen.py`.
+Here's a brief summary: given a layer in the chunk graph and a mip level of the segmentation data, we find the longest side length of a chunk in that layer
+in nm and the shortest component of a voxel in that mip level in nm. Preliminarily, select that longest length to be the side length of the Draco cube,
+then select the smallest n such that the Draco bin size is less than that shortest component over the square root of 2 (because of how marching cubes works).
+Then expand the Draco cube size until the bin size is an integer. Now the bins and origin are all at integers, making stitching possible.
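+
+As a worked illustration of that summary (with hypothetical numbers, not the values used in production), the selection can be sketched as:
+
+```
+import math
+
+chunk_length_nm = 16384        # longest chunk side length in nm (hypothetical)
+min_voxel_component_nm = 8.0   # shortest voxel component in nm (hypothetical)
+
+# preliminary cube size: the longest chunk side length
+cube_size = chunk_length_nm
+
+# smallest n such that the bin size cube_size / (2^n - 1) is below
+# min_voxel_component / sqrt(2)
+n = 1
+while cube_size / (2 ** n - 1) >= min_voxel_component_nm / math.sqrt(2):
+    n += 1
+
+# expand the cube until the bin size is an integer, so bins and origin
+# land on integer coordinates and boundaries stitch across chunks
+bins = 2 ** n - 1
+cube_size = math.ceil(cube_size / bins) * bins
+bin_size = cube_size // bins
+```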
diff --git a/docs/segmentation_preprocessing.md b/docs/segmentation_preprocessing.md
new file mode 100644
index 000000000..3fb1bf59b
--- /dev/null
+++ b/docs/segmentation_preprocessing.md
@@ -0,0 +1,41 @@
+# Preprocessing
+
+The goal of preprocessing an over-segmentation is to store the voxel and edge data in such a way that a ChunkedGraph can be created from it.
+
+## Tools
+The ChunkedGraph makes heavy use of [CloudVolume](https://github.com/seung-lab/cloud-volume/) for interfacing with volumetric data and [CloudFiles](https://github.com/seung-lab/cloud-files) for all other data on Google Cloud. It is recommended to use these tools during preprocessing to ensure compatibility.
+
+## Data formats
+
+[`precomputed`](https://github.com/google/neuroglancer/tree/master/src/neuroglancer/datasource/precomputed), introduced by [neuroglancer](https://github.com/google/neuroglancer), is a commonly used data format for volumetric data and is supported by CloudVolume. The supervoxel segmentation for the ChunkedGraph should ideally be stored using precomputed. Further, we recommend storing the supervoxel segmentation on Google Cloud in the same zone the ChunkedGraph server will be deployed in to reduce latency and avoid egress costs (Google Cloud does not charge for within-zone egress).
+
+The ChunkedGraph's format is called [`graphene`](https://github.com/seung-lab/cloud-volume/wiki/Graphene), which builds on precomputed. It combines the supervoxel segmentation with an agglomeration graph to provide a dynamic segmentation that can be edited.
+
+## Segmentation IDs
+
+Graphene follows a strict ID nomenclature for all supervoxels and other nodes in the ChunkedGraph hierarchy. In practice that means that most segmentations need to be rewritten to follow this nomenclature. Each ID stores data about its location in space (Chunk coord) and in the ChunkedGraph hierarchy (Layer id):
+
+![](https://user-images.githubusercontent.com/2517065/77118406-7dbd5a00-6a0a-11ea-96bb-003b83beb866.png)
+
+The number of bits for the layer id is fixed to 8 bits. Supervoxels are on layer 1. The number of bits for the chunk coordinates decreases with each layer as there are fewer chunks in higher layers but more segments per chunk.
+
+Before preprocessing one's segmentation one must determine the bounding box of the segmentation and then the number of bits needed to represent the chunk coordinates. CloudVolume supports bounding boxes that start at arbitrary locations in space (see also `cv.bounds` for a CloudVolume instance of a precomputed segmentation). Once ingested, the bounding box cannot be changed anymore. It is advantageous to keep the bounding box as small as possible.
+
+After determining the bounding box one can extract the necessary number of bits for the chunk coords by counting the number of chunks in each dimension and calculating how many bits are required to represent the ids in the largest dimension.
The number of bits for all dimensions are identical. + +Lastly, each supervoxel within a chunk is assigned an ID that is unique among all supervoxels _within_ the same chunk. This ID can be assigned at random but we recommend assigning IDs from a sequential ID space starting at 0. + +## Storing supervoxel edges and components + +There are three types of edges: +1. `in_chunk`: edges between supervoxels within a chunk +2. `cross_chunk`: edges between parts of "the same" supervoxel in the unchunked segmentation that has been split across chunk boundary +3. `between_chunk`: edges between supervoxels across chunks + +Every pair of touching supervoxels has an edge between them. All edges are stored using [protobuf](https://github.com/seung-lab/PyChunkedGraph/blob/pcgv2/pychunkedgraph/io/protobuf/chunkEdges.proto). During ingest only edges of type 2. and 3. are copied into BigTable, whereas edges of type 1. are always read from storage to reduce cost. Similar to the supervoxel segmentation, we recommed storing these on GCloud in the same zone the ChunkedGraph server will be deployed in to reduce latency. + +To denote which edges form a connected component within a chunk, a component mapping needs to be created. This mapping is only used during ingest. + +More details on how to create these protobuf files can be found [here](https://github.com/seung-lab/PyChunkedGraph/blob/pcgv2/docs/storage.md). + + diff --git a/ingest/README.md b/ingest/README.md new file mode 100644 index 000000000..95e1daa49 --- /dev/null +++ b/ingest/README.md @@ -0,0 +1 @@ +This section has moved to [CAVEpipelines](https://github.com/seung-lab/CAVEpipelines). \ No newline at end of file diff --git a/kube-dev/1.redis.yml b/kube-dev/1.redis.yml deleted file mode 100644 index 4e9db473a..000000000 --- a/kube-dev/1.redis.yml +++ /dev/null @@ -1,35 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: redis - labels: - app: pcg-redis -spec: - type: NodePort - ports: - - port: 6379 - targetPort: 6379 - selector: - app: pcg-redis ---- -apiVersion: v1 -kind: Pod -metadata: - name: redis - labels: - app: pcg-redis -spec: - restartPolicy: Never - containers: - - name: redis - image: redis:5.0.4-alpine - imagePullPolicy: Always - resources: - requests: - memory: 10Gi - command: ["redis-server"] - args: ["--requirepass", "dev", "--save", "", "--appendonly", "no"] - ports: - - containerPort: 6379 - nodeSelector: - nodetype: redis-server \ No newline at end of file diff --git a/kube-dev/2.pcg.yml b/kube-dev/2.pcg.yml deleted file mode 100644 index 78a0549b0..000000000 --- a/kube-dev/2.pcg.yml +++ /dev/null @@ -1,67 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: pcg-svc - labels: - app: pcg-master -spec: - type: LoadBalancer - ports: - - port: 4000 - targetPort: 4000 - name: pcg - - port: 9181 - targetPort: 9181 - name: rq-dashboard - selector: - app: pcg-master ---- -apiVersion: v1 -kind: Pod -metadata: - name: pcg-master - labels: - app: pcg-master -spec: - restartPolicy: Never # for development - volumes: - - name: google-secret - secret: - secretName: google-secret - - name: microns-seunglab-google-secret - secret: - secretName: microns-seunglab-google-secret - - name: seunglab2-google-secret - secret: - secretName: seunglab2-google-secret - containers: - - name: pcg - image: gcr.io/neuromancer-seung-import/pychunkedgraph:akhilesh-fafb-mesh-worker - imagePullPolicy: Always - resources: - requests: - memory: 2Gi - env: - - name: APP_SETTINGS - value: pychunkedgraph.app.config.DeploymentWithRedisConfig - - name: FLASK_APP - 
value: run_dev_cli.py - - name: REDIS_PASSWORD - value: dev # for development - ports: - - containerPort: 4000 - volumeMounts: - - name: google-secret - mountPath: /root/.cloudvolume/secrets/google-secret.json - subPath: google-secret.json - readOnly: true - - name: microns-seunglab-google-secret - mountPath: /root/.cloudvolume/secrets/microns-seunglab-google-secret.json - subPath: microns-seunglab-google-secret.json - readOnly: true - - name: seunglab2-google-secret - mountPath: /root/.cloudvolume/secrets/seunglab2-google-secret.json - subPath: seunglab2-google-secret.json - readOnly: true - nodeSelector: - nodetype: pcg-master \ No newline at end of file diff --git a/kube-dev/3.pcg-worker.yml b/kube-dev/3.pcg-worker.yml deleted file mode 100644 index c69df0481..000000000 --- a/kube-dev/3.pcg-worker.yml +++ /dev/null @@ -1,57 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: pcg-worker - labels: - app: pcg-worker -spec: - selector: - matchLabels: - app: pcg-worker - replicas: 1 - template: - metadata: - labels: - app: pcg-worker - spec: - dnsPolicy: Default - volumes: - - name: google-secret - secret: - secretName: google-secret - - name: microns-seunglab-google-secret - secret: - secretName: microns-seunglab-google-secret - - name: seunglab2-google-secret - secret: - secretName: seunglab2-google-secret - containers: - - name: pcg-worker - image: gcr.io/neuromancer-seung-import/pychunkedgraph:akhilesh-fafb-mesh-worker - imagePullPolicy: Always - resources: - requests: - memory: 3Gi - env: - - name: APP_SETTINGS - value: pychunkedgraph.app.config.DeploymentWithRedisConfig - - name: FLASK_APP - value: run_dev_cli.py - - name: REDIS_PASSWORD - value: dev # for development - volumeMounts: - - name: google-secret - mountPath: /root/.cloudvolume/secrets/google-secret.json - subPath: google-secret.json - readOnly: true - - name: microns-seunglab-google-secret - mountPath: /root/.cloudvolume/secrets/microns-seunglab-google-secret.json - subPath: microns-seunglab-google-secret.json - readOnly: true - - name: seunglab2-google-secret - mountPath: /root/.cloudvolume/secrets/seunglab2-google-secret.json - subPath: seunglab2-google-secret.json - readOnly: true - command: ["rq"] - args: ["worker", "-c", "rq_workers.test_worker"] - \ No newline at end of file diff --git a/kube-dev/4.hpa.yml b/kube-dev/4.hpa.yml deleted file mode 100644 index 3d19210ac..000000000 --- a/kube-dev/4.hpa.yml +++ /dev/null @@ -1,15 +0,0 @@ -apiVersion: autoscaling/v1 -kind: HorizontalPodAutoscaler -metadata: - name: pcg-hpa - namespace: default - labels: - app: pcg-worker -spec: - scaleTargetRef: - apiVersion: apps/v1 - kind: Deployment - name: pcg-worker - minReplicas: 1 - maxReplicas: 8 - targetCPUUtilizationPercentage: 80 \ No newline at end of file diff --git a/override/gcloud b/override/gcloud new file mode 100755 index 000000000..882fbc73e --- /dev/null +++ b/override/gcloud @@ -0,0 +1,55 @@ +#!/app/venv/bin/python + +import argparse +import os +import subprocess +import re + +CONFIG_ENV = os.path.expanduser("~/.config/gcloud/emulators/bigtable/env.yaml") + +def env_init(args): + try: + with open(CONFIG_ENV, "r") as f: + hostport = re.findall(r"BIGTABLE_EMULATOR_HOST:\s+([:\w]+:\d+)", f.read())[0] + print(f"export BIGTABLE_EMULATOR={hostport}") + except Exception: + print(f"export BIGTABLE_EMULATOR=localhost:9000") + +def start(args): + os.makedirs(os.path.dirname(CONFIG_ENV), exist_ok=True) + with open(CONFIG_ENV, "w") as f: + f.write(f"---\nBIGTABLE_EMULATOR_HOST: {args.host_port}") + + host, port = 
args.host_port.rsplit(':', 1) + subprocess.Popen(["cbtemulator", "-host", host, "-port", port], start_new_session=True) + +def usage(args): + print("""This is not gcloud. Only supported commands are: + - gcloud beta emulators bigtable env-init + - gcloud beta emulators bigtable start [--host-port localhost:9000]""") + +if __name__ == '__main__': + parser_gcloud = argparse.ArgumentParser(prog='gcloud') + parser_gcloud.set_defaults(func=usage) + subparser_gcloud = parser_gcloud.add_subparsers() + + parser_beta = subparser_gcloud.add_parser('beta') + subparser_beta = parser_beta.add_subparsers() + + parser_emulators = subparser_beta.add_parser('emulators') + subparser_emulators = parser_emulators.add_subparsers() + + parser_bigtable = subparser_emulators.add_parser('bigtable') + subparser_bigtable = parser_bigtable.add_subparsers() + + parser_env_init = subparser_bigtable.add_parser('env-init') + parser_env_init.set_defaults(func=env_init) + + parser_start = subparser_bigtable.add_parser('start') + parser_start.add_argument('--host-port', default='localhost:9000') + parser_start.set_defaults(func=start) + + args = parser_gcloud.parse_args() + args.func(args) + + diff --git a/override/jupyter_notebook_config.py b/override/jupyter_notebook_config.py deleted file mode 100644 index ecece8ce8..000000000 --- a/override/jupyter_notebook_config.py +++ /dev/null @@ -1,765 +0,0 @@ -# Configuration file for jupyter-notebook. - -#------------------------------------------------------------------------------ -# Application(SingletonConfigurable) configuration -#------------------------------------------------------------------------------ - -## This is an application. - -## The date format used by logging formatters for %(asctime)s -#c.Application.log_datefmt = '%Y-%m-%d %H:%M:%S' - -## The Logging format template -#c.Application.log_format = '[%(name)s]%(highlevel)s %(message)s' - -## Set the log level by value or name. -#c.Application.log_level = 30 - -#------------------------------------------------------------------------------ -# JupyterApp(Application) configuration -#------------------------------------------------------------------------------ - -## Base class for Jupyter applications - -## Answer yes to any prompts. -#c.JupyterApp.answer_yes = False - -## Full path of a config file. -#c.JupyterApp.config_file = '' - -## Specify a config file to load. -#c.JupyterApp.config_file_name = '' - -## Generate default config file. -#c.JupyterApp.generate_config = False - -#------------------------------------------------------------------------------ -# NotebookApp(JupyterApp) configuration -#------------------------------------------------------------------------------ - -## Set the Access-Control-Allow-Credentials: true header -#c.NotebookApp.allow_credentials = False - -## Set the Access-Control-Allow-Origin header -# -# Use '*' to allow any origin to access your server. -# -# Takes precedence over allow_origin_pat. -c.NotebookApp.allow_origin = '*' - -## Use a regular expression for the Access-Control-Allow-Origin header -# -# Requests from an origin matching the expression will get replies with: -# -# Access-Control-Allow-Origin: origin -# -# where `origin` is the origin of the request. -# -# Ignored if allow_origin is set. -#c.NotebookApp.allow_origin_pat = '' - -## Allow password to be changed at login for the notebook server. 
-# -# While loggin in with a token, the notebook server UI will give the opportunity -# to the user to enter a new password at the same time that will replace the -# token login mechanism. -# -# This can be set to false to prevent changing password from the UI/API. -#c.NotebookApp.allow_password_change = True - -## Allow requests where the Host header doesn't point to a local server -# -# By default, requests get a 403 forbidden response if the 'Host' header shows -# that the browser thinks it's on a non-local domain. Setting this option to -# True disables this check. -# -# This protects against 'DNS rebinding' attacks, where a remote web server -# serves you a page and then changes its DNS to send later requests to a local -# IP, bypassing same-origin checks. -# -# Local IP addresses (such as 127.0.0.1 and ::1) are allowed as local, along -# with hostnames configured in local_hostnames. -#c.NotebookApp.allow_remote_access = False - -## Whether to allow the user to run the notebook as root. -#c.NotebookApp.allow_root = False - -## DEPRECATED use base_url -#c.NotebookApp.base_project_url = '/' - -## The base URL for the notebook server. -# -# Leading and trailing slashes can be omitted, and will automatically be added. -#c.NotebookApp.base_url = '/' - -## Specify what command to use to invoke a web browser when opening the notebook. -# If not specified, the default browser will be determined by the `webbrowser` -# standard library module, which allows setting of the BROWSER environment -# variable to override it. -#c.NotebookApp.browser = '' - -## The full path to an SSL/TLS certificate file. -#c.NotebookApp.certfile = '' - -## The full path to a certificate authority certificate for SSL/TLS client -# authentication. -#c.NotebookApp.client_ca = '' - -## The config manager class to use -#c.NotebookApp.config_manager_class = 'notebook.services.config.manager.ConfigManager' - -## The notebook manager class to use. -#c.NotebookApp.contents_manager_class = 'notebook.services.contents.largefilemanager.LargeFileManager' - -## Extra keyword arguments to pass to `set_secure_cookie`. See tornado's -# set_secure_cookie docs for details. -#c.NotebookApp.cookie_options = {} - -## The random bytes used to secure cookies. By default this is a new random -# number every time you start the Notebook. Set it to a value in a config file -# to enable logins to persist across server sessions. -# -# Note: Cookie secrets should be kept private, do not share config files with -# cookie_secret stored in plaintext (you can read the value from a file). -#c.NotebookApp.cookie_secret = b'' - -## The file where the cookie secret is stored. -#c.NotebookApp.cookie_secret_file = '' - -## Override URL shown to users. -# -# Replace actual URL, including protocol, address, port and base URL, with the -# given value when displaying URL to the users. Do not change the actual -# connection URL. If authentication token is enabled, the token is added to the -# custom URL automatically. -# -# This option is intended to be used when the URL to display to the user cannot -# be determined reliably by the Jupyter notebook server (proxified or -# containerized setups for example). 
-#c.NotebookApp.custom_display_url = '' - -## The default URL to redirect to from `/` -#c.NotebookApp.default_url = '/tree' - -## Disable cross-site-request-forgery protection -# -# Jupyter notebook 4.3.1 introduces protection from cross-site request -# forgeries, requiring API requests to either: -# -# - originate from pages served by this server (validated with XSRF cookie and -# token), or - authenticate with a token -# -# Some anonymous compute resources still desire the ability to run code, -# completely without authentication. These services can disable all -# authentication and security checks, with the full knowledge of what that -# implies. -#c.NotebookApp.disable_check_xsrf = False - -## Whether to enable MathJax for typesetting math/TeX -# -# MathJax is the javascript library Jupyter uses to render math/LaTeX. It is -# very large, so you may want to disable it if you have a slow internet -# connection, or for offline use of the notebook. -# -# When disabled, equations etc. will appear as their untransformed TeX source. -#c.NotebookApp.enable_mathjax = True - -## extra paths to look for Javascript notebook extensions -#c.NotebookApp.extra_nbextensions_path = [] - -## handlers that should be loaded at higher priority than the default services -#c.NotebookApp.extra_services = [] - -## Extra paths to search for serving static files. -# -# This allows adding javascript/css to be available from the notebook server -# machine, or overriding individual files in the IPython -#c.NotebookApp.extra_static_paths = [] - -## Extra paths to search for serving jinja templates. -# -# Can be used to override templates from notebook.templates. -#c.NotebookApp.extra_template_paths = [] - -## -#c.NotebookApp.file_to_run = '' - -## Extra keyword arguments to pass to `get_secure_cookie`. See tornado's -# get_secure_cookie docs for details. -#c.NotebookApp.get_secure_cookie_kwargs = {} - -## Deprecated: Use minified JS file or not, mainly use during dev to avoid JS -# recompilation -#c.NotebookApp.ignore_minified_js = False - -## (bytes/sec) Maximum rate at which stream output can be sent on iopub before -# they are limited. -#c.NotebookApp.iopub_data_rate_limit = 1000000 - -## (msgs/sec) Maximum rate at which messages can be sent on iopub before they are -# limited. -#c.NotebookApp.iopub_msg_rate_limit = 1000 - -## The IP address the notebook server will listen on. -c.NotebookApp.ip = '0.0.0.0' - -## Supply extra arguments that will be passed to Jinja environment. -#c.NotebookApp.jinja_environment_options = {} - -## Extra variables to supply to jinja templates when rendering. -#c.NotebookApp.jinja_template_vars = {} - -## The kernel manager class to use. -#c.NotebookApp.kernel_manager_class = 'notebook.services.kernels.kernelmanager.MappingKernelManager' - -## The kernel spec manager class to use. Should be a subclass of -# `jupyter_client.kernelspec.KernelSpecManager`. -# -# The Api of KernelSpecManager is provisional and might change without warning -# between this version of Jupyter and the next stable one. -#c.NotebookApp.kernel_spec_manager_class = 'jupyter_client.kernelspec.KernelSpecManager' - -## The full path to a private key file for usage with SSL/TLS. -#c.NotebookApp.keyfile = '' - -## Hostnames to allow as local when allow_remote_access is False. -# -# Local IP addresses (such as 127.0.0.1 and ::1) are automatically accepted as -# local as well. -#c.NotebookApp.local_hostnames = ['localhost'] - -## The login handler class to use. 
-#c.NotebookApp.login_handler_class = 'notebook.auth.login.LoginHandler' - -## The logout handler class to use. -#c.NotebookApp.logout_handler_class = 'notebook.auth.logout.LogoutHandler' - -## The MathJax.js configuration file that is to be used. -#c.NotebookApp.mathjax_config = 'TeX-AMS-MML_HTMLorMML-full,Safe' - -## A custom url for MathJax.js. Should be in the form of a case-sensitive url to -# MathJax, for example: /static/components/MathJax/MathJax.js -#c.NotebookApp.mathjax_url = '' - -## Sets the maximum allowed size of the client request body, specified in the -# Content-Length request header field. If the size in a request exceeds the -# configured value, a malformed HTTP message is returned to the client. -# -# Note: max_body_size is applied even in streaming mode. -#c.NotebookApp.max_body_size = 536870912 - -## Gets or sets the maximum amount of memory, in bytes, that is allocated for -# use by the buffer manager. -#c.NotebookApp.max_buffer_size = 536870912 - -## Dict of Python modules to load as notebook server extensions.Entry values can -# be used to enable and disable the loading ofthe extensions. The extensions -# will be loaded in alphabetical order. -#c.NotebookApp.nbserver_extensions = {} - -## The directory to use for notebooks and kernels. -#c.NotebookApp.notebook_dir = '' - -## Whether to open in a browser after starting. The specific browser used is -# platform dependent and determined by the python standard library `webbrowser` -# module, unless it is overridden using the --browser (NotebookApp.browser) -# configuration option. -#c.NotebookApp.open_browser = True - -## Hashed password to use for web authentication. -# -# To generate, type in a python/IPython shell: -# -# from notebook.auth import passwd; passwd() -# -# The string should be of the form type:salt:hashed-password. -c.NotebookApp.password = 'sha1:99c4a251cbc1:7b1916ef1147bd40891739b56192177a5f00e468' - -## Forces users to use a password for the Notebook server. This is useful in a -# multi user environment, for instance when everybody in the LAN can access each -# other's machine through ssh. -# -# In such a case, server the notebook server on localhost is not secure since -# any user can connect to the notebook server via ssh. -#c.NotebookApp.password_required = False - -## The port the notebook server will listen on. -#c.NotebookApp.port = 8888 - -## The number of additional ports to try if the specified port is not available. -#c.NotebookApp.port_retries = 50 - -## DISABLED: use %pylab or %matplotlib in the notebook to enable matplotlib. -#c.NotebookApp.pylab = 'disabled' - -## If True, display a button in the dashboard to quit (shutdown the notebook -# server). -#c.NotebookApp.quit_button = True - -## (sec) Time window used to check the message and data rate limits. -#c.NotebookApp.rate_limit_window = 3 - -## Reraise exceptions encountered loading server extensions? -#c.NotebookApp.reraise_server_extension_failures = False - -## DEPRECATED use the nbserver_extensions dict instead -#c.NotebookApp.server_extensions = [] - -## The session manager class to use. -#c.NotebookApp.session_manager_class = 'notebook.services.sessions.sessionmanager.SessionManager' - -## Shut down the server after N seconds with no kernels or terminals running and -# no activity. This can be used together with culling idle kernels -# (MappingKernelManager.cull_idle_timeout) to shutdown the notebook server when -# it's not in use. This is not precisely timed: it may shut down up to a minute -# later. 
0 (the default) disables this automatic shutdown. -#c.NotebookApp.shutdown_no_activity_timeout = 0 - -## Supply SSL options for the tornado HTTPServer. See the tornado docs for -# details. -#c.NotebookApp.ssl_options = {} - -## Supply overrides for terminado. Currently only supports "shell_command". -#c.NotebookApp.terminado_settings = {} - -## Set to False to disable terminals. -# -# This does *not* make the notebook server more secure by itself. Anything the -# user can in a terminal, they can also do in a notebook. -# -# Terminals may also be automatically disabled if the terminado package is not -# available. -#c.NotebookApp.terminals_enabled = True - -## Token used for authenticating first-time connections to the server. -# -# When no password is enabled, the default is to generate a new, random token. -# -# Setting to an empty string disables authentication altogether, which is NOT -# RECOMMENDED. -#c.NotebookApp.token = '' - -## Supply overrides for the tornado.web.Application that the Jupyter notebook -# uses. -#c.NotebookApp.tornado_settings = {} - -## Whether to trust or not X-Scheme/X-Forwarded-Proto and X-Real-Ip/X-Forwarded- -# For headerssent by the upstream reverse proxy. Necessary if the proxy handles -# SSL -#c.NotebookApp.trust_xheaders = False - -## DEPRECATED, use tornado_settings -#c.NotebookApp.webapp_settings = {} - -## Specify Where to open the notebook on startup. This is the `new` argument -# passed to the standard library method `webbrowser.open`. The behaviour is not -# guaranteed, but depends on browser support. Valid values are: -# -# - 2 opens a new tab, -# - 1 opens a new window, -# - 0 opens in an existing window. -# -# See the `webbrowser.open` documentation for details. -#c.NotebookApp.webbrowser_open_new = 2 - -## Set the tornado compression options for websocket connections. -# -# This value will be returned from -# :meth:`WebSocketHandler.get_compression_options`. None (default) will disable -# compression. A dict (even an empty one) will enable compression. -# -# See the tornado docs for WebSocketHandler.get_compression_options for details. -#c.NotebookApp.websocket_compression_options = None - -## The base URL for websockets, if it differs from the HTTP server (hint: it -# almost certainly doesn't). -# -# Should be in the form of an HTTP origin: ws[s]://hostname[:port] -#c.NotebookApp.websocket_url = '' - -#------------------------------------------------------------------------------ -# ConnectionFileMixin(LoggingConfigurable) configuration -#------------------------------------------------------------------------------ - -## Mixin for configurable classes that work with connection files - -## JSON file in which to store connection info [default: kernel-.json] -# -# This file will contain the IP, ports, and authentication key needed to connect -# clients to this kernel. By default, this file will be created in the security -# dir of the current profile, but can be specified by absolute path. -#c.ConnectionFileMixin.connection_file = '' - -## set the control (ROUTER) port [default: random] -#c.ConnectionFileMixin.control_port = 0 - -## set the heartbeat port [default: random] -#c.ConnectionFileMixin.hb_port = 0 - -## set the iopub (PUB) port [default: random] -#c.ConnectionFileMixin.iopub_port = 0 - -## Set the kernel's IP address [default localhost]. If the IP address is -# something other than localhost, then Consoles on other machines will be able -# to connect to the Kernel, so be careful! 
-#c.ConnectionFileMixin.ip = '' - -## set the shell (ROUTER) port [default: random] -#c.ConnectionFileMixin.shell_port = 0 - -## set the stdin (ROUTER) port [default: random] -#c.ConnectionFileMixin.stdin_port = 0 - -## -#c.ConnectionFileMixin.transport = 'tcp' - -#------------------------------------------------------------------------------ -# KernelManager(ConnectionFileMixin) configuration -#------------------------------------------------------------------------------ - -## Manages a single kernel in a subprocess on this host. -# -# This version starts kernels with Popen. - -## Should we autorestart the kernel if it dies. -#c.KernelManager.autorestart = True - -## DEPRECATED: Use kernel_name instead. -# -# The Popen Command to launch the kernel. Override this if you have a custom -# kernel. If kernel_cmd is specified in a configuration file, Jupyter does not -# pass any arguments to the kernel, because it cannot make any assumptions about -# the arguments that the kernel understands. In particular, this means that the -# kernel does not receive the option --debug if it given on the Jupyter command -# line. -#c.KernelManager.kernel_cmd = [] - -## Time to wait for a kernel to terminate before killing it, in seconds. -#c.KernelManager.shutdown_wait_time = 5.0 - -#------------------------------------------------------------------------------ -# Session(Configurable) configuration -#------------------------------------------------------------------------------ - -## Object for handling serialization and sending of messages. -# -# The Session object handles building messages and sending them with ZMQ sockets -# or ZMQStream objects. Objects can communicate with each other over the -# network via Session objects, and only need to work with the dict-based IPython -# message spec. The Session will handle serialization/deserialization, security, -# and metadata. -# -# Sessions support configurable serialization via packer/unpacker traits, and -# signing with HMAC digests via the key/keyfile traits. -# -# Parameters ---------- -# -# debug : bool -# whether to trigger extra debugging statements -# packer/unpacker : str : 'json', 'pickle' or import_string -# importstrings for methods to serialize message parts. If just -# 'json' or 'pickle', predefined JSON and pickle packers will be used. -# Otherwise, the entire importstring must be used. -# -# The functions must accept at least valid JSON input, and output *bytes*. -# -# For example, to use msgpack: -# packer = 'msgpack.packb', unpacker='msgpack.unpackb' -# pack/unpack : callables -# You can also set the pack/unpack callables for serialization directly. -# session : bytes -# the ID of this Session object. The default is to generate a new UUID. -# username : unicode -# username added to message headers. The default is to ask the OS. -# key : bytes -# The key used to initialize an HMAC signature. If unset, messages -# will not be signed or checked. -# keyfile : filepath -# The file containing a key. If this is set, `key` will be initialized -# to the contents of the file. - -## Threshold (in bytes) beyond which an object's buffer should be extracted to -# avoid pickling. -#c.Session.buffer_threshold = 1024 - -## Whether to check PID to protect against calls after fork. -# -# This check can be disabled if fork-safety is handled elsewhere. -#c.Session.check_pid = True - -## Threshold (in bytes) beyond which a buffer should be sent without copying. 
-#c.Session.copy_threshold = 65536 - -## Debug output in the Session -#c.Session.debug = False - -## The maximum number of digests to remember. -# -# The digest history will be culled when it exceeds this value. -#c.Session.digest_history_size = 65536 - -## The maximum number of items for a container to be introspected for custom -# serialization. Containers larger than this are pickled outright. -#c.Session.item_threshold = 64 - -## execution key, for signing messages. -#c.Session.key = b'' - -## path to file containing execution key. -#c.Session.keyfile = '' - -## Metadata dictionary, which serves as the default top-level metadata dict for -# each message. -#c.Session.metadata = {} - -## The name of the packer for serializing messages. Should be one of 'json', -# 'pickle', or an import name for a custom callable serializer. -#c.Session.packer = 'json' - -## The UUID identifying this session. -#c.Session.session = '' - -## The digest scheme used to construct the message signatures. Must have the form -# 'hmac-HASH'. -#c.Session.signature_scheme = 'hmac-sha256' - -## The name of the unpacker for unserializing messages. Only used with custom -# functions for `packer`. -#c.Session.unpacker = 'json' - -## Username for the Session. Default is your system username. -#c.Session.username = 'username' - -#------------------------------------------------------------------------------ -# MultiKernelManager(LoggingConfigurable) configuration -#------------------------------------------------------------------------------ - -## A class for managing multiple kernels. - -## The name of the default kernel to start -#c.MultiKernelManager.default_kernel_name = 'python3' - -## The kernel manager class. This is configurable to allow subclassing of the -# KernelManager for customized behavior. -#c.MultiKernelManager.kernel_manager_class = 'jupyter_client.ioloop.IOLoopKernelManager' - -#------------------------------------------------------------------------------ -# MappingKernelManager(MultiKernelManager) configuration -#------------------------------------------------------------------------------ - -## A KernelManager that handles notebook mapping and HTTP error handling - -## Whether messages from kernels whose frontends have disconnected should be -# buffered in-memory. -# -# When True (default), messages are buffered and replayed on reconnect, avoiding -# lost messages due to interrupted connectivity. -# -# Disable if long-running kernels will produce too much output while no -# frontends are connected. -#c.MappingKernelManager.buffer_offline_messages = True - -## Whether to consider culling kernels which are busy. Only effective if -# cull_idle_timeout > 0. -#c.MappingKernelManager.cull_busy = False - -## Whether to consider culling kernels which have one or more connections. Only -# effective if cull_idle_timeout > 0. -#c.MappingKernelManager.cull_connected = False - -## Timeout (in seconds) after which a kernel is considered idle and ready to be -# culled. Values of 0 or lower disable culling. Very short timeouts may result -# in kernels being culled for users with poor network connections. -#c.MappingKernelManager.cull_idle_timeout = 0 - -## The interval (in seconds) on which to check for idle kernels exceeding the -# cull timeout value. -#c.MappingKernelManager.cull_interval = 300 - -## Timeout for giving up on a kernel (in seconds). -# -# On starting and restarting kernels, we check whether the kernel is running and -# responsive by sending kernel_info_requests. 
This sets the timeout in seconds -# for how long the kernel can take before being presumed dead. This affects the -# MappingKernelManager (which handles kernel restarts) and the -# ZMQChannelsHandler (which handles the startup). -#c.MappingKernelManager.kernel_info_timeout = 60 - -## -#c.MappingKernelManager.root_dir = '' - -#------------------------------------------------------------------------------ -# ContentsManager(LoggingConfigurable) configuration -#------------------------------------------------------------------------------ - -## Base class for serving files and directories. -# -# This serves any text or binary file, as well as directories, with special -# handling for JSON notebook documents. -# -# Most APIs take a path argument, which is always an API-style unicode path, and -# always refers to a directory. -# -# - unicode, not url-escaped -# - '/'-separated -# - leading and trailing '/' will be stripped -# - if unspecified, path defaults to '', -# indicating the root path. - -## Allow access to hidden files -#c.ContentsManager.allow_hidden = False - -## -#c.ContentsManager.checkpoints = None - -## -#c.ContentsManager.checkpoints_class = 'notebook.services.contents.checkpoints.Checkpoints' - -## -#c.ContentsManager.checkpoints_kwargs = {} - -## handler class to use when serving raw file requests. -# -# Default is a fallback that talks to the ContentsManager API, which may be -# inefficient, especially for large files. -# -# Local files-based ContentsManagers can use a StaticFileHandler subclass, which -# will be much more efficient. -# -# Access to these files should be Authenticated. -#c.ContentsManager.files_handler_class = 'notebook.files.handlers.FilesHandler' - -## Extra parameters to pass to files_handler_class. -# -# For example, StaticFileHandlers generally expect a `path` argument specifying -# the root directory from which to serve files. -#c.ContentsManager.files_handler_params = {} - -## Glob patterns to hide in file and directory listings. -#c.ContentsManager.hide_globs = ['__pycache__', '*.pyc', '*.pyo', '.DS_Store', '*.so', '*.dylib', '*~'] - -## Python callable or importstring thereof -# -# To be called on a contents model prior to save. -# -# This can be used to process the structure, such as removing notebook outputs -# or other side effects that should not be saved. -# -# It will be called as (all arguments passed by keyword):: -# -# hook(path=path, model=model, contents_manager=self) -# -# - model: the model to be saved. Includes file contents. -# Modifying this dict will affect the file that is stored. -# - path: the API path of the save destination -# - contents_manager: this ContentsManager instance -#c.ContentsManager.pre_save_hook = None - -## -#c.ContentsManager.root_dir = '/' - -## The base name used when creating untitled directories. -#c.ContentsManager.untitled_directory = 'Untitled Folder' - -## The base name used when creating untitled files. -#c.ContentsManager.untitled_file = 'untitled' - -## The base name used when creating untitled notebooks. -#c.ContentsManager.untitled_notebook = 'Untitled' - -#------------------------------------------------------------------------------ -# FileManagerMixin(Configurable) configuration -#------------------------------------------------------------------------------ - -## Mixin for ContentsAPI classes that interact with the filesystem. -# -# Provides facilities for reading, writing, and copying both notebooks and -# generic files. -# -# Shared by FileContentsManager and FileCheckpoints. 
-# -# Note ---- Classes using this mixin must provide the following attributes: -# -# root_dir : unicode -# A directory against against which API-style paths are to be resolved. -# -# log : logging.Logger - -## By default notebooks are saved on disk on a temporary file and then if -# succefully written, it replaces the old ones. This procedure, namely -# 'atomic_writing', causes some bugs on file system whitout operation order -# enforcement (like some networked fs). If set to False, the new notebook is -# written directly on the old one which could fail (eg: full filesystem or quota -# ) -#c.FileManagerMixin.use_atomic_writing = True - -#------------------------------------------------------------------------------ -# FileContentsManager(FileManagerMixin,ContentsManager) configuration -#------------------------------------------------------------------------------ - -## If True (default), deleting files will send them to the platform's -# trash/recycle bin, where they can be recovered. If False, deleting files -# really deletes them. -#c.FileContentsManager.delete_to_trash = True - -## Python callable or importstring thereof -# -# to be called on the path of a file just saved. -# -# This can be used to process the file on disk, such as converting the notebook -# to a script or HTML via nbconvert. -# -# It will be called as (all arguments passed by keyword):: -# -# hook(os_path=os_path, model=model, contents_manager=instance) -# -# - path: the filesystem path to the file just written - model: the model -# representing the file - contents_manager: this ContentsManager instance -#c.FileContentsManager.post_save_hook = None - -## -#c.FileContentsManager.root_dir = '' - -## DEPRECATED, use post_save_hook. Will be removed in Notebook 5.0 -#c.FileContentsManager.save_script = False - -#------------------------------------------------------------------------------ -# NotebookNotary(LoggingConfigurable) configuration -#------------------------------------------------------------------------------ - -## A class for computing and verifying notebook signatures. - -## The hashing algorithm used to sign notebooks. -#c.NotebookNotary.algorithm = 'sha256' - -## The sqlite file in which to store notebook signatures. By default, this will -# be in your Jupyter data directory. You can set it to ':memory:' to disable -# sqlite writing to the filesystem. -#c.NotebookNotary.db_file = '' - -## The secret key with which notebooks are signed. -#c.NotebookNotary.secret = b'' - -## The file where the secret key is stored. -#c.NotebookNotary.secret_file = '' - -## A callable returning the storage backend for notebook signatures. The default -# uses an SQLite database. -#c.NotebookNotary.store_factory = traitlets.Undefined - -#------------------------------------------------------------------------------ -# KernelSpecManager(LoggingConfigurable) configuration -#------------------------------------------------------------------------------ - -## If there is no Python kernelspec registered and the IPython kernel is -# available, ensure it is added to the spec list. -#c.KernelSpecManager.ensure_native_kernel = True - -## The kernel spec class. This is configurable to allow subclassing of the -# KernelSpecManager for customized behavior. -#c.KernelSpecManager.kernel_spec_class = 'jupyter_client.kernelspec.KernelSpec' - -## Whitelist of allowed kernel names. -# -# By default, all installed kernels are allowed. 
-#c.KernelSpecManager.whitelist = set() diff --git a/override/supervisord.conf b/override/supervisord.conf index 37c371e07..46b7a42b5 100644 --- a/override/supervisord.conf +++ b/override/supervisord.conf @@ -3,7 +3,7 @@ nodaemon=true [program:uwsgi] -command=/usr/local/bin/uwsgi --ini /etc/uwsgi/uwsgi.ini --die-on-term --need-app +command=%(ENV_VIRTUAL_ENV)s/bin/uwsgi --ini /etc/uwsgi/uwsgi.ini --die-on-term --need-app --stats 127.0.0.1:9192 --stats-http stdout_logfile=/dev/stdout stdout_logfile_maxbytes=0 stderr_logfile=/dev/stderr diff --git a/pychunkedgraph/__init__.py b/pychunkedgraph/__init__.py index f221d99f1..6e9a24925 100644 --- a/pychunkedgraph/__init__.py +++ b/pychunkedgraph/__init__.py @@ -1 +1 @@ -__version__ = 'fafb.1.21' \ No newline at end of file +__version__ = "2.19.2" diff --git a/pychunkedgraph/admin/table_prep.py b/pychunkedgraph/admin/table_prep.py deleted file mode 100644 index b8e3c79c5..000000000 --- a/pychunkedgraph/admin/table_prep.py +++ /dev/null @@ -1,50 +0,0 @@ - - -from pychunkedgraph.exporting import export - - -def _log_type(log_entry): - if "removed_edges" in log_entry: - return "split" - else: - return "merge" - -def apply_log(cg, log): - assert cg.table_id != "pinky100_sv11" - - last_operation_id = -1 - for operation_id in log.keys(): - assert last_operation_id < int(operation_id) - - log_entry = log[operation_id] - - print(log_entry) - - if _log_type(log_entry) == "merge": - print("MERGE") - if len(log_entry["added_edges"]) == 0: - affinities = None - else: - affinities = log_entry["added_edges"] - - cg.add_edges(user_id=log_entry["user"], - atomic_edges=log_entry["added_edges"], - affinities=affinities, - source_coord=log_entry["source_coords"], - sink_coord=log_entry["sink_coords"], - n_tries=60) - elif _log_type(log_entry) == "split": - print("SPLIT") - cg.remove_edges(user_id=log_entry["user"], - source_ids=log_entry["source_ids"], - sink_ids=log_entry["sink_ids"], - source_coords=log_entry["source_coords"], - sink_coords=log_entry["sink_coords"], - atomic_edges=log_entry["removed_edges"], - mincut=False, - bb_offset=log_entry["bb_offset"], - n_tries=20) - else: - raise NotImplementedError - - last_operation_id = int(operation_id) \ No newline at end of file diff --git a/pychunkedgraph/app/__init__.py b/pychunkedgraph/app/__init__.py index e5c35e48c..3e938628b 100644 --- a/pychunkedgraph/app/__init__.py +++ b/pychunkedgraph/app/__init__.py @@ -1,76 +1,110 @@ -from flask import Flask -from flask.logging import default_handler -from flask_cors import CORS -import sys +import datetime +import json import logging import os +import sys import time -import json + +import pandas as pd import numpy as np -import datetime -from . import config import redis +from flask import Flask +from flask.json.provider import DefaultJSONProvider +from flask.logging import default_handler +from flask_cors import CORS from rq import Queue -# from pychunkedgraph.app import app_blueprint -from pychunkedgraph.app import cg_app_blueprint, meshing_app_blueprint from pychunkedgraph.logging import jsonformatter -# from pychunkedgraph.app import manifest_app_blueprint -os.environ['TRAVIS_BRANCH'] = "IDONTKNOWWHYINEEDTHIS" + +from . 
import config +from .meshing.legacy.routes import bp as meshing_api_legacy +from .meshing.v1.routes import bp as meshing_api_v1 +from .segmentation.legacy.routes import bp as segmentation_api_legacy +from .segmentation.v1.routes import bp as segmentation_api_v1 +from .segmentation.generic.routes import bp as generic_api +from .app_utils import get_instance_folder_path class CustomJsonEncoder(json.JSONEncoder): + def __init__(self, int64_as_str=False, **kwargs): + super().__init__(**kwargs) + self.int64_as_str = int64_as_str + def default(self, obj): if isinstance(obj, np.ndarray): + if self.int64_as_str and obj.dtype.type in (np.int64, np.uint64): + return obj.astype(str).tolist() return obj.tolist() elif isinstance(obj, np.generic): + if self.int64_as_str and obj.dtype.type in (np.int64, np.uint64): + return obj.astype(str).item() return obj.item() elif isinstance(obj, datetime.datetime): return obj.__str__() + elif isinstance(obj, pd.DataFrame): + return obj.to_json() return json.JSONEncoder.default(self, obj) +class CustomJSONProvider(DefaultJSONProvider): + def dumps(self, obj, **kwargs): + return super().dumps(obj, default=None, cls=CustomJsonEncoder, **kwargs) + + def create_app(test_config=None): - app = Flask(__name__) - app.json_encoder = CustomJsonEncoder + app = Flask( + __name__, + instance_path=get_instance_folder_path(), + instance_relative_config=True, + ) + app.json = CustomJSONProvider(app) - CORS(app, expose_headers='WWW-Authenticate') + CORS(app, expose_headers="WWW-Authenticate") configure_app(app) if test_config is not None: app.config.update(test_config) - app.register_blueprint(cg_app_blueprint.bp) - app.register_blueprint(meshing_app_blueprint.bp) - # app.register_blueprint(manifest_app_blueprint.bp) + app.register_blueprint(generic_api) + + app.register_blueprint(meshing_api_legacy) + app.register_blueprint(meshing_api_v1) + + app.register_blueprint(segmentation_api_legacy) + app.register_blueprint(segmentation_api_v1) return app def configure_app(app): # Load logging scheme from config.py - app_settings = os.getenv('APP_SETTINGS') + app_settings = os.getenv("APP_SETTINGS") if not app_settings: app.config.from_object(config.BaseConfig) else: app.config.from_object(app_settings) - - + app.config.from_pyfile("config.cfg", silent=True) # Configure logging # handler = logging.FileHandler(app.config['LOGGING_LOCATION']) handler = logging.StreamHandler(sys.stdout) - handler.setLevel(app.config['LOGGING_LEVEL']) + handler.setLevel(app.config["LOGGING_LEVEL"]) formatter = jsonformatter.JsonFormatter( - fmt=app.config['LOGGING_FORMAT'], - datefmt=app.config['LOGGING_DATEFORMAT']) + fmt=app.config["LOGGING_FORMAT"], datefmt=app.config["LOGGING_DATEFORMAT"] + ) formatter.converter = time.gmtime handler.setFormatter(formatter) app.logger.removeHandler(default_handler) app.logger.addHandler(handler) - app.logger.setLevel(app.config['LOGGING_LEVEL']) + app.logger.setLevel(app.config["LOGGING_LEVEL"]) app.logger.propagate = False - if app.config['USE_REDIS_JOBS']: - app.redis = redis.Redis.from_url(app.config['REDIS_URL']) - app.test_q = Queue('test' ,connection=app.redis) \ No newline at end of file + if app.config["USE_REDIS_JOBS"]: + app.redis = redis.Redis.from_url(app.config["REDIS_URL"]) + app.test_q = Queue("test", connection=app.redis) + with app.app_context(): + from ..ingest.rq_cli import init_rq_cmds + from ..ingest.cli import init_ingest_cmds + + init_rq_cmds(app) + init_ingest_cmds(app) diff --git a/pychunkedgraph/app/app_test.py b/pychunkedgraph/app/app_test.py 
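The CustomJSONProvider wired in above pushes every Flask JSON response through CustomJsonEncoder, so numpy arrays, numpy scalars, datetimes and DataFrames serialize without per-endpoint conversion, and the optional int64_as_str flag emits 64-bit IDs as strings because JavaScript numbers lose exactness above 2^53 - 1. A minimal standalone sketch of that encoding rule (outside Flask; the sample root IDs are made up):

    import json
    import numpy as np

    # Hypothetical root IDs; uint64 values above 2**53 - 1 cannot be represented
    # exactly as JavaScript numbers, which is what int64_as_str is for.
    root_ids = np.array([648518346349539862, 648518346349539863], dtype=np.uint64)

    def encode(obj, int64_as_str=True):
        # Mirrors CustomJsonEncoder.default(): stringify 64-bit ints on request,
        # otherwise fall back to tolist()/item() for numpy containers and scalars.
        if isinstance(obj, np.ndarray):
            if int64_as_str and obj.dtype.type in (np.int64, np.uint64):
                return obj.astype(str).tolist()
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Not JSON serializable: {type(obj)}")

    print(json.dumps(root_ids, default=encode))
    # -> ["648518346349539862", "648518346349539863"]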
deleted file mode 100644 index 8bdff8f1c..000000000 --- a/pychunkedgraph/app/app_test.py +++ /dev/null @@ -1,106 +0,0 @@ -import time -import pytest -from pychunkedgraph.app import create_app - - -@pytest.fixture -def app(): - app = create_app( - { - 'TESTING': True, - 'BIGTABLE_CONFIG': { - 'emulate': True - } - } - ) - yield app - - -@pytest.fixture -def client(app): - return app.test_client() - -# TODO convert this to an actual self contained test with emulated backend -# and use app factory to create testing app and client objects - -# TODO setup fixture that puts data in backend before running client tests - - -def request(test_client, op, body, post=True): - - if post: - url = '/1.0/segment/{1}/{2}'.format(body[0], op) - body = [] - else: - url = '/1.0/graph/{1}'.format(op) - - print(url) - time_start = time.time() - response = test_client.get(url, verify=False, json=body) - - dt = (time.time() - time_start) * 1000 - print("%.3fms" % dt) - - return response - - -def get_root(client, atomic_id): - body = [str(atomic_id), 0, 0, 0] - - print(body) - r = request(client, "root", body, post=False) - - print(r.content) - return r - - -def get_children(client, parent_id): - body = [str(parent_id), 0, 0, 0] - - print(body) - r = request(client, "children", body, post=True) - - # print(r.content) - return r - - -def get_leaves(client, atomic_id): - body = [str(atomic_id), 0, 0, 0] - - print(body) - r = request(client, "leaves", body, post=True) - - # print(r.content) - return r - - -def get_leaves_from_leave(atomic_id): - body = [str(atomic_id), 0, 0, 0] - - print(body) - r = request("leaves_from_leave", body, post=True) - - # print(r.content) - return r - - -def merge(atomic_ids): - body = [[str(atomic_ids[0]), 0, 0, 0], - [str(atomic_ids[1]), 0, 0, 0]] - - print(body) - r = request("merge", body, post=False) - - # print(r.content) - return r - - -def split(atomic_ids): - body = [[str(atomic_ids[0]), 0, 0, 0], - [str(atomic_ids[1]), 0, 0, 0]] - - print(body) - r = request("split", body, post=False) - - # print(r.content) - return r diff --git a/pychunkedgraph/app/app_utils.py b/pychunkedgraph/app/app_utils.py index 84b73e23e..b46e4b192 100644 --- a/pychunkedgraph/app/app_utils.py +++ b/pychunkedgraph/app/app_utils.py @@ -1,96 +1,201 @@ -from flask import current_app -from google.auth import credentials, default as default_creds -from google.cloud import bigtable, datastore +# pylint: disable=invalid-name, missing-docstring, logging-fstring-interpolation + +import os +from typing import Sequence +from time import mktime +from functools import wraps -import sys import numpy as np -import logging -import time -import redis -import functools +import networkx as nx +import requests +from flask import current_app, json, request +from scipy import spatial +from werkzeug.datastructures import ImmutableMultiDict -from pychunkedgraph.logging import jsonformatter, flask_log_db -from pychunkedgraph.backend import chunkedgraph +from pychunkedgraph import __version__ +from pychunkedgraph.graph import ChunkedGraph +from pychunkedgraph.graph.client import get_default_client_info +from pychunkedgraph.graph import exceptions as cg_exceptions -cache = {} +PCG_CACHE = {} -class DoNothingCreds(credentials.Credentials): - def refresh(self, request): - pass +def get_app_base_path(): + return os.path.dirname(os.path.realpath(__file__)) -def get_bigtable_client(config): - project_id = config.get('project_id', 'pychunkedgraph') - if config.get('emulate', False): - credentials = DoNothingCreds() - else: - credentials, 
project_id = default_creds() +def get_instance_folder_path(): + return os.path.join(get_app_base_path(), "instance") - client = bigtable.Client(admin=True, - project=project_id, - credentials=credentials) - return client +def remap_public(func=None, *, edit=False, check_node_ids=False): + def mydecorator(f): + @wraps(f) + def decorated_function(*args, **kwargs): + virtual_tables = current_app.config.get("VIRTUAL_TABLES", None) -def get_datastore_client(config): - project_id = config.get('project_id', 'pychunkedgraph') + # if not virtual configuration just return + if virtual_tables is None: + return f(*args, **kwargs) + table_id = kwargs.get("table_id", None) + http_args = request.args.to_dict() - if config.get('emulate', False): - credentials = DoNothingCreds() - else: - credentials, project_id = default_creds() + if table_id is None: + # then no table remapping necessary + return f(*args, **kwargs) + if not table_id in virtual_tables: + # if table table_id isn't in virtual + # tables then just return + return f(*args, **kwargs) + else: + # then we have a virtual table + if edit: + raise cg_exceptions.Unauthorized( + "No edits allowed on virtual tables" + ) + # and we want to remap the table name + new_table = virtual_tables[table_id]["table_id"] + kwargs["table_id"] = new_table + v_timestamp = virtual_tables[table_id]["timestamp"] + v_timetamp_float = mktime(v_timestamp.timetuple()) + + # we want to fix timestamp parameters too + def ceiling_timestamp(argname): + old_arg = http_args.get(argname, None) + if old_arg is not None: + old_arg = float(old_arg) + # if they specified a timestamp + # enforce its less than the cap + if old_arg > v_timetamp_float: + http_args[argname] = v_timetamp_float + else: + # if they omit the timestamp, it defaults to "now" + # so we should cap it at the virtual timestamp + http_args[argname] = v_timetamp_float + + ceiling_timestamp("timestamp") + ceiling_timestamp("timestamp_future") + + request.args = ImmutableMultiDict(http_args) + + # we also want to check for endpoints + # which ask for info about IDs and + # restrict such calls to IDs that are valid + # before the timestamp cap for this virtual table + cg = get_cg(new_table) + + def assert_node_prop(prop): + node_id = kwargs.get(prop, None) + if node_id is not None: + node_id = int(node_id) + # check if this root_id is valid at this timestamp + timestamp = cg.get_node_timestamps([node_id]) + if not np.all(timestamp < np.datetime64(v_timestamp)): + raise cg_exceptions.Unauthorized( + "root_id not valid at timestamp" + ) + + assert_node_prop("root_id") + assert_node_prop("node_id") - client = datastore.Client(project=project_id, credentials=credentials) - return client + # some endpoints post node_ids as json, so we have to check there + # as well if the endpoint configured us to. 
+ if check_node_ids: + node_ids = np.array( + json.loads(request.data)["node_ids"], dtype=np.uint64 + ) + timestamps = cg.get_node_timestamps(node_ids) + if not np.all(timestamps < np.datetime64(v_timestamp)): + raise cg_exceptions.Unauthorized( + "node_ids are all not valid at timestamp" + ) + return f(*args, **kwargs) -def get_cg(table_id): - assert table_id.startswith("fly") or table_id.startswith("golden") or \ - table_id.startswith("pinky100_rv") + return decorated_function - if table_id not in cache: - instance_id = current_app.config['CHUNKGRAPH_INSTANCE_ID'] - client = get_bigtable_client(current_app.config) + if func: + return mydecorator(func) + else: + return mydecorator + + +def jsonify_with_kwargs(data, as_response=True, **kwargs): + kwargs.setdefault("separators", (",", ":")) - # Create ChunkedGraph logging - logger = logging.getLogger(f"{instance_id}/{table_id}") - logger.setLevel(current_app.config['LOGGING_LEVEL']) + if current_app.json.compact == False or current_app.debug: + kwargs["indent"] = 2 + kwargs["separators"] = (", ", ": ") + + resp = json.dumps(data, **kwargs) + if as_response: + return current_app.response_class( + resp + "\n", mimetype=current_app.json.mimetype + ) + else: + return resp - # prevent duplicate logs from Flasks(?) parent logger - logger.propagate = False - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(current_app.config['LOGGING_LEVEL']) - formatter = jsonformatter.JsonFormatter( - fmt=current_app.config['LOGGING_FORMAT'], - datefmt=current_app.config['LOGGING_DATEFORMAT']) - formatter.converter = time.gmtime - handler.setFormatter(formatter) +def ensure_correct_version(cg: ChunkedGraph) -> bool: + current_major_version = int(__version__.split(".", maxsplit=1)[0]) + try: + graph_major_version = int(cg.version.split(".")[0]) + valid = graph_major_version == current_major_version + assert valid, f"v{cg.version} not supported, server version {__version__}." + return True + except (AttributeError, TypeError): + # graph not versioned, later checked if whitelisted + return False - logger.addHandler(handler) - # Create ChunkedGraph - cache[table_id] = chunkedgraph.ChunkedGraph(table_id=table_id, - instance_id=instance_id, - client=client, - logger=logger) +def get_cg(table_id, skip_cache: bool = False): current_app.table_id = table_id - return cache[table_id] + if skip_cache is False: + try: + return PCG_CACHE[table_id] + except KeyError: + pass + cg = ChunkedGraph(graph_id=table_id, client_info=get_default_client_info()) + version_valid = ensure_correct_version(cg) + if version_valid: + PCG_CACHE[table_id] = cg + return cg + + if cg.graph_id in current_app.config["PCG_GRAPH_IDS"]: + current_app.logger.warning(f"Serving whitelisted graph {cg.graph_id}.") + PCG_CACHE[table_id] = cg + return cg + raise ValueError(f"Graph {cg.graph_id} not supported.") + + +def toboolean(value): + """Transform value to boolean type. + :param value: bool/int/str + :return: bool + :raises: ValueError, if value is not boolean. 
+ """ + if not value: + raise ValueError("Can't convert null to boolean") -def get_log_db(table_id): - if 'log_db' not in cache: - client = get_datastore_client(current_app.config) - cache["log_db"] = flask_log_db.FlaskLogDatabase(table_id, - client=client) + if isinstance(value, bool): + return value + try: + value = value.lower() + except Exception as exc: + raise ValueError(f"Can't convert {value} to boolean: {exc}") from exc - return cache["log_db"] + if value in ("true", "1"): + return True + if value in ("false", "0"): + return False + + raise ValueError(f"Can't convert {value} to boolean") def tobinary(ids): - """ Transform id(s) to binary format + """Transform id(s) to binary format :param ids: uint64 or list of uint64s :return: binary @@ -99,9 +204,86 @@ def tobinary(ids): def tobinary_multiples(arr): - """ Transform id(s) to binary format + """Transform id(s) to binary format :param arr: list of uint64 or list of uint64s :return: binary """ return [np.array(arr_i).tobytes() for arr_i in arr] + + +def handle_supervoxel_id_lookup( + cg, coordinates: Sequence[Sequence[int]], node_ids: Sequence[np.uint64] +) -> Sequence[np.uint64]: + """ + Helper to lookup supervoxel ids. + This takes care of grouping coordinates. + """ + + def ccs(coordinates_nm_): + graph = nx.Graph() + dist_mat = spatial.distance.cdist(coordinates_nm_, coordinates_nm_) + for edge in np.array(np.where(dist_mat < 1000)).T: + graph.add_edge(*edge) + ccs = [np.array(list(cc)) for cc in nx.connected_components(graph)] + return ccs + + coordinates = np.array(coordinates, dtype=int) + coordinates_nm = coordinates * cg.meta.resolution + max_dist_steps = np.array([4, 8, 14, 28], dtype=float) * np.mean(cg.meta.resolution) + + node_ids = np.array(node_ids, dtype=np.uint64) + if len(coordinates.shape) != 2: + raise cg_exceptions.BadRequest( + f"Could not determine supervoxel ID for coordinates " + f"{coordinates} - Validation stage." + ) + + atomic_ids = np.zeros(len(coordinates), dtype=np.uint64) + for node_id in np.unique(node_ids): + node_id_m = node_ids == node_id + for cc in ccs(coordinates_nm[node_id_m]): + m_ids = np.where(node_id_m)[0][cc] + + for max_dist_nm in max_dist_steps: + atomic_ids_sub = cg.get_atomic_ids_from_coords( + coordinates[m_ids], parent_id=node_id, max_dist_nm=max_dist_nm + ) + if atomic_ids_sub is not None: + break + if atomic_ids_sub is None: + raise cg_exceptions.BadRequest( + f"Could not determine supervoxel ID for coordinates " + f"{coordinates} - Lookup stage." 
+ ) + atomic_ids[m_ids] = atomic_ids_sub + return atomic_ids + + +def get_username_dict(user_ids, auth_token) -> dict: + AUTH_URL = os.environ.get("AUTH_URL", None) + if AUTH_URL is None: + raise cg_exceptions.ChunkedGraphError("No AUTH_URL defined") + + users_request = requests.get( + f"https://{AUTH_URL}/api/v1/username?id={','.join(map(str, np.unique(user_ids)))}", + headers={"authorization": "Bearer " + auth_token}, + timeout=5, + ) + return {x["id"]: x["name"] for x in users_request.json()} + + +def get_userinfo_dict(user_ids, auth_token): + AUTH_URL = os.environ.get("AUTH_URL", None) + + if AUTH_URL is None: + raise cg_exceptions.ChunkedGraphError("No AUTH_URL defined") + + users_request = requests.get( + f"https://{AUTH_URL}/api/v1/user?id={','.join(map(str, np.unique(user_ids)))}", + headers={"authorization": "Bearer " + auth_token}, + timeout=5, + ) + return {x["id"]: x["name"] for x in users_request.json()}, { + x["id"]: x["pi"] for x in users_request.json() + } diff --git a/pychunkedgraph/app/cg_app_blueprint.py b/pychunkedgraph/app/cg_app_blueprint.py deleted file mode 100644 index 96bff589d..000000000 --- a/pychunkedgraph/app/cg_app_blueprint.py +++ /dev/null @@ -1,620 +0,0 @@ -from flask import Blueprint, request, make_response, jsonify, current_app,\ - redirect, url_for, after_this_request, Response, g - -import json -import numpy as np -import time -from datetime import datetime -from pytz import UTC -import traceback -import collections -import requests -import threading - -from pychunkedgraph.app import app_utils, meshing_app_blueprint -from pychunkedgraph.backend import chunkedgraph_exceptions as cg_exceptions, \ - chunkedgraph_comp as cg_comp -from middle_auth_client import auth_required, auth_requires_roles - -__version__ = 'fafb.1.21' -bp = Blueprint('pychunkedgraph', __name__, url_prefix="/segmentation") - -# ------------------------------- -# ------ Access control and index -# ------------------------------- - - -@bp.route('/') -@bp.route("/index") -def index(): - return "PyChunkedGraph Server -- " + __version__ - - -@bp.route -def home(): - resp = make_response() - resp.headers['Access-Control-Allow-Origin'] = '*' - acah = "Origin, X-Requested-With, Content-Type, Accept" - resp.headers["Access-Control-Allow-Headers"] = acah - resp.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS" - resp.headers["Connection"] = "keep-alive" - return resp - - -# ------------------------------- -# ------ Measurements and Logging -# ------------------------------- - -@bp.before_request -def before_request(): - current_app.request_start_time = time.time() - current_app.request_start_date = datetime.utcnow() - - -@bp.after_request -def after_request(response): - dt = (time.time() - current_app.request_start_time) * 1000 - - current_app.logger.debug("Response time: %.3fms" % dt) - - try: - log_db = app_utils.get_log_db(current_app.table_id) - log_db.add_success_log(user_id="", user_ip="", - request_time=current_app.request_start_date, - response_time=dt, url=request.url, - request_data=request.data, - request_type=current_app.request_type) - except: - current_app.logger.debug("LogDB entry not successful") - - return response - - -@bp.errorhandler(Exception) -def unhandled_exception(e): - status_code = 500 - response_time = (time.time() - current_app.request_start_time) * 1000 - user_ip = str(request.remote_addr) - tb = traceback.format_exception(etype=type(e), value=e, - tb=e.__traceback__) - - current_app.logger.error({ - "message": str(e), - "user_id": user_ip, - "user_ip": 
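handle_supervoxel_id_lookup above first clusters the requested points: coordinates are scaled to nanometres, any pair closer than 1000 nm is linked, and each connected component is looked up together, retrying with growing search radii. A small sketch of just the clustering step, with an assumed 4x4x40 nm resolution and made-up points:

    import numpy as np
    import networkx as nx
    from scipy import spatial

    resolution = np.array([4, 4, 40])            # assumed voxel resolution in nm
    coordinates = np.array([[100, 100, 10],      # made-up click locations (voxels)
                            [110, 105, 10],
                            [5000, 5000, 200]])
    coordinates_nm = coordinates * resolution

    # Same idea as ccs() in handle_supervoxel_id_lookup(): points closer than
    # 1000 nm land in one group and share a single supervoxel lookup, which is
    # then retried with growing max_dist_nm steps (4/8/14/28 x mean resolution).
    graph = nx.Graph()
    dist_mat = spatial.distance.cdist(coordinates_nm, coordinates_nm)
    for edge in np.array(np.where(dist_mat < 1000)).T:
        graph.add_edge(*edge)

    groups = [sorted(cc) for cc in nx.connected_components(graph)]
    print(groups)  # two clusters: points 0 and 1 together, point 2 alone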
user_ip, - "request_time": current_app.request_start_date, - "request_url": request.url, - "request_data": request.data, - "response_time": response_time, - "response_code": status_code, - "traceback": tb - }) - - resp = { - 'timestamp': current_app.request_start_date, - 'duration': response_time, - 'code': status_code, - 'message': str(e), - 'traceback': tb - } - - return jsonify(resp), status_code - - -@bp.errorhandler(cg_exceptions.ChunkedGraphAPIError) -def api_exception(e): - response_time = (time.time() - current_app.request_start_time) * 1000 - user_ip = str(request.remote_addr) - tb = traceback.format_exception(etype=type(e), value=e, - tb=e.__traceback__) - - current_app.logger.error({ - "message": str(e), - "user_id": user_ip, - "user_ip": user_ip, - "request_time": current_app.request_start_date, - "request_url": request.url, - "request_data": request.data, - "response_time": response_time, - "response_code": e.status_code.value, - "traceback": tb - }) - - resp = { - 'timestamp': current_app.request_start_date, - 'duration': response_time, - 'code': e.status_code.value, - 'message': str(e) - } - - return jsonify(resp), e.status_code.value - - -# ------------------- -# ------ Applications -# ------------------- - - -@bp.route("/sleep/") -def sleep_me(sleep): - current_app.request_type = "sleep" - - time.sleep(sleep) - return "zzz... {} ... awake".format(sleep) - - -@bp.route('/1.0//info', methods=['GET']) -def handle_info(table_id): - current_app.request_type = "info" - - cg = app_utils.get_cg(table_id) - - return jsonify(cg.dataset_info) - -### GET ROOT ------------------------------------------------------------------- - -@bp.route('/1.0//graph/root', methods=['POST', 'GET']) -def handle_root_1(table_id): - atomic_id = np.uint64(json.loads(request.data)[0]) - - # Convert seconds since epoch to UTC datetime - try: - timestamp = float(request.args.get('timestamp', time.time())) - timestamp = datetime.fromtimestamp(timestamp, UTC) - except (TypeError, ValueError) as e: - raise(cg_exceptions.BadRequest("Timestamp parameter is not a valid" - " unix timestamp")) - - return handle_root_main(table_id, atomic_id, timestamp) - - -@bp.route('/1.0//graph//root', methods=['POST', 'GET']) -def handle_root_2(table_id, atomic_id): - - # Convert seconds since epoch to UTC datetime - try: - timestamp = float(request.args.get('timestamp', time.time())) - timestamp = datetime.fromtimestamp(timestamp, UTC) - except (TypeError, ValueError) as e: - raise(cg_exceptions.BadRequest("Timestamp parameter is not a valid" - " unix timestamp")) - - return handle_root_main(table_id, np.uint64(atomic_id), timestamp) - - -def handle_root_main(table_id, atomic_id, timestamp): - current_app.request_type = "root" - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - root_id = cg.get_root(np.uint64(atomic_id), time_stamp=timestamp) - - # Return binary - return app_utils.tobinary(root_id) - - -### MERGE ---------------------------------------------------------------------- - -@bp.route('/1.0//graph/merge', methods=['POST', 'GET']) -@auth_requires_roles('edit_all') -def handle_merge(table_id): - current_app.request_type = "merge" - - nodes = json.loads(request.data) - user_id = str(g.auth_user['id']) - - current_app.logger.debug(nodes) - assert len(nodes) == 2 - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - - atomic_edge = [] - coords = [] - for node in nodes: - node_id = node[0] - x, y, z = node[1:] - coordinate = np.array([x, y, z]) / cg.segmentation_resolution - - atomic_id = 
cg.get_atomic_id_from_coord(coordinate[0], - coordinate[1], - coordinate[2], - parent_id=np.uint64(node_id)) - - if atomic_id is None: - raise cg_exceptions.BadRequest( - f"Could not determine supervoxel ID for coordinates " - f"{coordinate}." - ) - - coords.append(coordinate) - atomic_edge.append(atomic_id) - - # Protection from long range mergers - chunk_coord_delta = cg.get_chunk_coordinates(atomic_edge[0]) - \ - cg.get_chunk_coordinates(atomic_edge[1]) - - if np.any(np.abs(chunk_coord_delta) > 3): - raise cg_exceptions.BadRequest( - "Chebyshev distance between merge points exceeded allowed maximum " - "(3 chunks).") - - try: - ret = cg.add_edges( - user_id=user_id, - atomic_edges=np.array(atomic_edge, dtype=np.uint64), - source_coord=coords[:1], - sink_coord=coords[1:], - ) - - except cg_exceptions.LockingError as e: - raise cg_exceptions.InternalServerError("Could not acquire root lock for merge operation.") - except cg_exceptions.PreconditionError as e: - raise cg_exceptions.BadRequest(str(e)) - - if ret.new_root_ids is None: - raise cg_exceptions.InternalServerError("Could not merge selected supervoxel.") - - current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids)) - - if len(ret.new_lvl2_ids) > 0: - t = threading.Thread( - target=meshing_app_blueprint._remeshing, - args=(cg.get_serialized_info(), ret.new_lvl2_ids), - ) - t.start() - - # NOTE: JS can't safely read integers larger than 2^53 - 1 - resp = { - "operation_id": ret.operation_id, - "operation_id_str": str(ret.operation_id), - "new_root_ids": ret.new_root_ids, - "new_root_ids_str": list(map(str, ret.new_root_ids)), - } - return jsonify(resp) - - -### SPLIT ---------------------------------------------------------------------- - -@bp.route('/1.0//graph/split', methods=['POST', 'GET']) -@auth_requires_roles('edit_all') -def handle_split(table_id): - current_app.request_type = "split" - - data = json.loads(request.data) - user_id = str(g.auth_user['id']) - - current_app.logger.debug(data) - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - - data_dict = {} - for k in ["sources", "sinks"]: - data_dict[k] = collections.defaultdict(list) - - for node in data[k]: - node_id = node[0] - x, y, z = node[1:] - coordinate = np.array([x, y, z]) / cg.segmentation_resolution - - atomic_id = cg.get_atomic_id_from_coord(coordinate[0], - coordinate[1], - coordinate[2], - parent_id=np.uint64( - node_id)) - - if atomic_id is None: - raise cg_exceptions.BadRequest( - f"Could not determine supervoxel ID for coordinates " - f"{coordinate}.") - - data_dict[k]["id"].append(atomic_id) - data_dict[k]["coord"].append(coordinate) - - current_app.logger.debug(data_dict) - - try: - ret = cg.remove_edges( - user_id=user_id, - source_ids=data_dict["sources"]["id"], - sink_ids=data_dict["sinks"]["id"], - source_coords=data_dict["sources"]["coord"], - sink_coords=data_dict["sinks"]["coord"], - mincut=True, - ) - - except cg_exceptions.LockingError as e: - raise cg_exceptions.InternalServerError("Could not acquire root lock for split operation.") - except cg_exceptions.PreconditionError as e: - raise cg_exceptions.BadRequest(str(e)) - - if ret.new_root_ids is None: - raise cg_exceptions.InternalServerError("Could not split selected segment groups.") - - current_app.logger.debug(("after split:", ret.new_root_ids)) - current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids)) - - if len(ret.new_lvl2_ids) > 0: - t = threading.Thread( - target=meshing_app_blueprint._remeshing, - args=(cg.get_serialized_info(), ret.new_lvl2_ids), - ) - t.start() - - # 
NOTE: JS can't safely read integers larger than 2^53 - 1 - resp = { - "operation_id": ret.operation_id, - "operation_id_str": str(ret.operation_id), - "new_root_ids": ret.new_root_ids, - "new_root_ids_str": list(map(str, ret.new_root_ids)), - } - return jsonify(resp) - - -### UNDO ---------------------------------------------------------------------- - - -@bp.route("/1.0//graph/undo", methods=["POST"]) -@auth_requires_roles("edit_all") -def handle_undo(table_id): - current_app.request_type = "undo" - - data = json.loads(request.data) - user_id = str(g.auth_user["id"]) - - current_app.logger.debug(data) - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - operation_id = np.uint64(data["operation_id"]) - - try: - ret = cg.undo(user_id=user_id, operation_id=operation_id) - except cg_exceptions.LockingError as e: - raise cg_exceptions.InternalServerError("Could not acquire root lock for undo operation.") - except (cg_exceptions.PreconditionError, cg_exceptions.PostconditionError) as e: - raise cg_exceptions.BadRequest(str(e)) - - current_app.logger.debug(("after undo:", ret.new_root_ids)) - current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids)) - - if ret.new_lvl2_ids.size > 0: - t = threading.Thread( - target=meshing_app_blueprint._remeshing, - args=(cg.get_serialized_info(), ret.new_lvl2_ids), - ) - t.start() - - # NOTE: JS can't safely read integers larger than 2^53 - 1 - resp = { - "operation_id": ret.operation_id, - "operation_id_str": str(ret.operation_id), - "new_root_ids": ret.new_root_ids, - "new_root_ids_str": list(map(str, ret.new_root_ids)), - } - return jsonify(resp) - - -### REDO ---------------------------------------------------------------------- - - -@bp.route("/1.0//graph/redo", methods=["POST"]) -@auth_requires_roles("edit_all") -def handle_redo(table_id): - current_app.request_type = "redo" - - data = json.loads(request.data) - user_id = str(g.auth_user["id"]) - - current_app.logger.debug(data) - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - operation_id = np.uint64(data["operation_id"]) - - try: - ret = cg.redo(user_id=user_id, operation_id=operation_id) - except cg_exceptions.LockingError as e: - raise cg_exceptions.InternalServerError("Could not acquire root lock for redo operation.") - except (cg_exceptions.PreconditionError, cg_exceptions.PostconditionError) as e: - raise cg_exceptions.BadRequest(str(e)) - - current_app.logger.debug(("after redo:", ret.new_root_ids)) - current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids)) - - if ret.new_lvl2_ids.size > 0: - t = threading.Thread( - target=meshing_app_blueprint._remeshing, - args=(cg.get_serialized_info(), ret.new_lvl2_ids), - ) - t.start() - - # NOTE: JS can't safely read integers larger than 2^53 - 1 - resp = { - "operation_id": ret.operation_id, - "operation_id_str": str(ret.operation_id), - "new_root_ids": ret.new_root_ids, - "new_root_ids_str": list(map(str, ret.new_root_ids)), - } - return jsonify(resp) - - -### CHILDREN ------------------------------------------------------------------- - -@bp.route('/1.0//segment//children', - methods=['POST', 'GET']) -def handle_children(table_id, parent_id): - current_app.request_type = "children" - - cg = app_utils.get_cg(table_id) - - parent_id = np.uint64(parent_id) - layer = cg.get_chunk_layer(parent_id) - - if layer > 1: - children = cg.get_children(parent_id) - else: - children = np.array([]) - - # Return binary - return app_utils.tobinary(children) - - -### LEAVES --------------------------------------------------------------------- - 
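The legacy handlers above answer with raw uint64 bytes rather than JSON; app_utils.tobinary presumably boils down to np.array(ids).tobytes(), matching the tobinary_multiples helper kept in the rewritten app_utils.py. A sketch of the round trip a client would do (made-up IDs):

    import numpy as np

    # Made-up node IDs; tobinary() is assumed to return np.array(ids).tobytes().
    ids = np.array([72057594037927936, 72057594037927937], dtype=np.uint64)
    payload = np.array(ids).tobytes()

    # Clients reverse it with a zero-copy frombuffer read.
    decoded = np.frombuffer(payload, dtype=np.uint64)
    assert np.array_equal(decoded, ids)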
-@bp.route('/1.0//segment//leaves', methods=['POST', 'GET']) -def handle_leaves(table_id, root_id): - current_app.request_type = "leaves" - - if "bounds" in request.args: - bounds = request.args["bounds"] - bounding_box = np.array([b.split("-") for b in bounds.split("_")], - dtype=np.int).T - else: - bounding_box = None - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - atomic_ids = cg.get_subgraph_nodes(int(root_id), - bounding_box=bounding_box, - bb_is_coordinate=True) - - # Return binary - return app_utils.tobinary(atomic_ids) - - -### LEAVES FROM LEAVES --------------------------------------------------------- - -@bp.route('/1.0//segment//leaves_from_leave', - methods=['POST', 'GET']) -def handle_leaves_from_leave(table_id, atomic_id): - current_app.request_type = "leaves_from_leave" - - if "bounds" in request.args: - bounds = request.args["bounds"] - bounding_box = np.array([b.split("-") for b in bounds.split("_")], - dtype=np.int).T - else: - bounding_box = None - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - root_id = cg.get_root(int(atomic_id)) - - atomic_ids = cg.get_subgraph_nodes(root_id, - bounding_box=bounding_box, - bb_is_coordinate=True) - # Return binary - return app_utils.tobinary(np.concatenate([np.array([root_id]), atomic_ids])) - - -### SUBGRAPH ------------------------------------------------------------------- - -@bp.route('/1.0//segment//subgraph', methods=['POST', 'GET']) -def handle_subgraph(table_id, root_id): - current_app.request_type = "subgraph" - - if "bounds" in request.args: - bounds = request.args["bounds"] - bounding_box = np.array([b.split("-") for b in bounds.split("_")], - dtype=np.int).T - else: - bounding_box = None - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - atomic_edges = cg.get_subgraph_edges(int(root_id), - bounding_box=bounding_box, - bb_is_coordinate=True)[0] - # Return binary - return app_utils.tobinary(atomic_edges) - - -### CHANGE LOG ----------------------------------------------------------------- - -@bp.route('/1.0//segment//change_log', - methods=["POST", "GET"]) -def change_log(table_id, root_id): - current_app.request_type = "change_log" - - try: - time_stamp_past = float(request.args.get('timestamp', 0)) - time_stamp_past = datetime.fromtimestamp(time_stamp_past, UTC) - except (TypeError, ValueError) as e: - raise(cg_exceptions.BadRequest("Timestamp parameter is not a valid" - " unix timestamp")) - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - - change_log = cg.get_change_log(root_id=np.uint64(root_id), - correct_for_wrong_coord_type=True, - time_stamp_past=time_stamp_past) - - return jsonify(change_log) - - -@bp.route('/1.0//segment//merge_log', - methods=["POST", "GET"]) -def merge_log(table_id, root_id): - current_app.request_type = "merge_log" - - try: - time_stamp_past = float(request.args.get('timestamp', 0)) - time_stamp_past = datetime.fromtimestamp(time_stamp_past, UTC) - except (TypeError, ValueError) as e: - raise(cg_exceptions.BadRequest("Timestamp parameter is not a valid" - " unix timestamp")) - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - - change_log = cg.get_change_log(root_id=np.uint64(root_id), - correct_for_wrong_coord_type=True, - time_stamp_past=time_stamp_past) - - for k in list(change_log.keys()): - if not "merge" in k: - del change_log[k] - continue - - return jsonify(change_log) - - -@bp.route('/1.0//graph/oldest_timestamp', methods=["POST", "GET"]) -def oldest_timestamp(table_id): - current_app.request_type = "timestamp" - - cg = 
app_utils.get_cg(table_id) - - oldest_log_row = cg.read_first_log_row() - - if oldest_log_row is None: - raise Exception("No log row found") - - return jsonify(list(oldest_log_row.values())[0][0].timestamp) - - -### CONTACT SITES -------------------------------------------------------------- - -@bp.route('/1.0//segment//contact_sites', - methods=["POST", "GET"]) -def handle_contact_sites(table_id, root_id): - partners = request.args.get('partners', False) - - if "bounds" in request.args: - bounds = request.args["bounds"] - bounding_box = np.array([b.split("-") for b in bounds.split("_")], - dtype=np.int).T - else: - bounding_box = None - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id) - - cs_dict = cg_comp.get_contact_sites(cg, np.uint64(root_id), - bounding_box = bounding_box, - compute_partner=partners) - - return jsonify(cs_dict) \ No newline at end of file diff --git a/pychunkedgraph/app/common.py b/pychunkedgraph/app/common.py new file mode 100644 index 000000000..237e11fc0 --- /dev/null +++ b/pychunkedgraph/app/common.py @@ -0,0 +1,150 @@ +# pylint: disable=invalid-name, missing-docstring, unspecified-encoding + +import os +import json +import time +import traceback +from datetime import datetime + +from cloudvolume import compression +from google.api_core.exceptions import GoogleAPIError +from flask import current_app, g, jsonify, request + +from pychunkedgraph.logging.log_db import get_log_db + +USER_NOT_FOUND = "-1" + +ENABLE_LOGS = os.environ.get("PCG_SERVER_ENABLE_LOGS", "") != "" +LOG_LEAVES_MANY = os.environ.get("PCG_SERVER_LOGS_LEAVES_MANY", "") != "" + + +def _log_request(response_time): + try: + current_app.user_id = g.auth_user["id"] + except (AttributeError, KeyError): + current_app.user_id = USER_NOT_FOUND + + if ENABLE_LOGS is False: + return + + if LOG_LEAVES_MANY is False and "leaves_many" in request.path: + return + + try: + if current_app.table_id is not None: + log_db = get_log_db(current_app.table_id) + args = dict(request.args) # request.args is ImmutableMultiDict + args.pop("middle_auth_token", None) + log_db.log_endpoint( + path=request.path, + endpoint=request.endpoint, + args=json.dumps(args), + user_id=current_app.user_id, + operation_id=current_app.operation_id, + request_ts=current_app.request_start_date, + response_time=response_time, + ) + except GoogleAPIError as e: + current_app.logger.error(f"LogDB entry not successful: GoogleAPIError {e}") + + +def before_request(): + current_app.request_start_time = time.time() + current_app.request_start_date = datetime.utcnow() + try: + current_app.user_id = g.auth_user["id"] + except (AttributeError, KeyError): + current_app.user_id = USER_NOT_FOUND + current_app.table_id = None + current_app.operation_id = None + current_app.request_type = None + content_encoding = request.headers.get("Content-Encoding", "") + if "gzip" in content_encoding.lower(): + request.data = compression.decompress(request.data, "gzip") + + +def after_request(response): + response_time = (time.time() - current_app.request_start_time) * 1000 + accept_encoding = request.headers.get("Accept-Encoding", "") + + _log_request(response_time) + + if "gzip" not in accept_encoding.lower(): + return response + + response.direct_passthrough = False + if ( + response.status_code < 200 + or response.status_code >= 300 + or "Content-Encoding" in response.headers + ): + return response + + response.data = compression.gzip_compress(response.data) + response.headers["Content-Encoding"] = "gzip" + response.headers["Vary"] = "Accept-Encoding" + 
response.headers["Content-Length"] = len(response.data) + return response + + +def unhandled_exception(e): + status_code = 500 + response_time = (time.time() - current_app.request_start_time) * 1000 + user_ip = str(request.remote_addr) + tb = traceback.format_exception(e) + + _log_request(response_time) + + current_app.logger.error( + { + "message": str(e), + "user_id": user_ip, + "user_ip": user_ip, + "request_time": current_app.request_start_date, + "request_url": request.url, + "request_data": request.data, + "response_time": response_time, + "response_code": status_code, + "traceback": tb, + } + ) + + resp = { + "timestamp": current_app.request_start_date, + "duration": response_time, + "code": status_code, + "message": str(e), + "traceback": tb, + } + + return jsonify(resp), status_code + + +def api_exception(e): + response_time = (time.time() - current_app.request_start_time) * 1000 + user_ip = str(request.remote_addr) + tb = traceback.format_exception(e) + + _log_request(response_time) + + current_app.logger.error( + { + "message": str(e), + "user_id": user_ip, + "user_ip": user_ip, + "request_time": current_app.request_start_date, + "request_url": request.url, + "request_data": request.data, + "response_time": response_time, + "response_code": e.status_code.value, + "traceback": tb, + } + ) + + resp = { + "timestamp": current_app.request_start_date, + "duration": response_time, + "code": e.status_code.value, + "message": str(e), + } + return jsonify(resp), e.status_code.value diff --git a/pychunkedgraph/app/config.py b/pychunkedgraph/app/config.py index c2b629b18..2f2a92e47 100644 --- a/pychunkedgraph/app/config.py +++ b/pychunkedgraph/app/config.py @@ -1,44 +1,82 @@ +# pylint: disable=invalid-name, missing-docstring, unspecified-encoding, line-too-long, too-few-public-methods + import logging import os +import json +import datetime class BaseConfig(object): DEBUG = False TESTING = False - HOME = os.path.expanduser("~") - # TODO get this secret out of source control - SECRET_KEY = '1d94e52c-1c89-4515-b87a-f48cf3cb7f0b' LOGGING_FORMAT = '{"source":"%(name)s","time":"%(asctime)s","severity":"%(levelname)s","message":"%(message)s"}' - LOGGING_DATEFORMAT = '%Y-%m-%dT%H:%M:%S.0Z' + LOGGING_DATEFORMAT = "%Y-%m-%dT%H:%M:%S.0Z" LOGGING_LEVEL = logging.DEBUG CHUNKGRAPH_INSTANCE_ID = "pychunkedgraph" - - # TODO what is this suppose to be by default? 
- CHUNKGRAPH_TABLE_ID = "pinky100_sv16" - # CHUNKGRAPH_TABLE_ID = "pinky100_benchmark_v92" + PROJECT_ID = os.environ.get("PROJECT_ID", None) + CG_READ_ONLY = os.environ.get("CG_READ_ONLY", None) is not None + PCG_GRAPH_IDS = os.environ.get("PCG_GRAPH_IDS", "").split(",") USE_REDIS_JOBS = False + daf_credential_path = os.environ.get("DAF_CREDENTIALS", None) + + AUTH_TOKEN = None + if daf_credential_path is not None: + with open(daf_credential_path, "r") as f: + AUTH_TOKEN = json.load(f)["token"] + + AUTH_SERVICE_NAMESPACE = "pychunkedgraph" + VIRTUAL_TABLES = { + "minnie65_public_v117": { + "table_id": "minnie3_v1", + "timestamp": datetime.datetime( + year=2021, + month=6, + day=11, + hour=8, + minute=10, + second=0, + microsecond=253, + tzinfo=datetime.timezone.utc, + ), + } + } + class DevelopmentConfig(BaseConfig): """Development configuration.""" + USE_REDIS_JOBS = False DEBUG = True + LOGGING_LEVEL = logging.ERROR + + +class DockerDevelopmentConfig(DevelopmentConfig): + """Development configuration.""" + + USE_REDIS_JOBS = True + REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") + REDIS_PORT = os.environ.get("REDIS_PORT", "6379") + REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "dev") + REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0" class DeploymentWithRedisConfig(BaseConfig): """Deployment configuration with Redis.""" + USE_REDIS_JOBS = True - REDIS_HOST = os.environ.get('REDIS_SERVICE_HOST') - REDIS_PORT = os.environ.get('REDIS_SERVICE_PORT') - REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD') - REDIS_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0' + REDIS_HOST = os.environ.get("REDIS_HOST") + REDIS_PORT = os.environ.get("REDIS_PORT", "6379") + REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") + REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0" class TestingConfig(BaseConfig): """Testing configuration.""" + TESTING = True USE_REDIS_JOBS = False PRESERVE_CONTEXT_ON_EXCEPTION = False diff --git a/pychunkedgraph/admin/__init__.py b/pychunkedgraph/app/meshing/__init__.py similarity index 100% rename from pychunkedgraph/admin/__init__.py rename to pychunkedgraph/app/meshing/__init__.py diff --git a/pychunkedgraph/app/meshing/common.py b/pychunkedgraph/app/meshing/common.py new file mode 100644 index 000000000..8f1a0c20a --- /dev/null +++ b/pychunkedgraph/app/meshing/common.py @@ -0,0 +1,206 @@ +# pylint: disable=invalid-name, missing-docstring +import json +import os +import threading + +import numpy as np +import redis +from rq import Queue, Connection, Retry +from flask import Response, current_app, jsonify, make_response, request + +from pychunkedgraph import __version__ +from pychunkedgraph.app import app_utils +from pychunkedgraph.graph import chunkedgraph +from pychunkedgraph.app.meshing import tasks as meshing_tasks +from pychunkedgraph.meshing import meshgen +from pychunkedgraph.meshing.manifest import get_highest_child_nodes_with_meshes +from pychunkedgraph.meshing.manifest import get_children_before_start_layer +from pychunkedgraph.meshing.manifest import ManifestCache + + +__meshing_url_prefix__ = os.environ.get("MESHING_URL_PREFIX", "meshing") + + +def str2bool(v): + return v.lower() in ("yes", "true", "t", "1") + + +def index(): + return f"PyChunkedGraph Meshing v{__version__}" + + +def home(): + resp = make_response() + resp.headers["Access-Control-Allow-Origin"] = "*" + acah = "Origin, X-Requested-With, Content-Type, Accept" + resp.headers["Access-Control-Allow-Headers"] = acah + 
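Each VIRTUAL_TABLES entry above pins a public table name to a real graph plus a cut-off timestamp; remap_public in app_utils.py (earlier in this patch) rewrites the table_id and clamps any timestamp or timestamp_future query argument to that cut-off. The clamping itself reduces to a few lines; a sketch using the minnie65_public_v117 entry:

    import datetime
    from time import mktime
    from typing import Optional

    # The minnie65_public_v117 cut-off from VIRTUAL_TABLES above.
    v_timestamp = datetime.datetime(2021, 6, 11, 8, 10, 0, 253, tzinfo=datetime.timezone.utc)
    v_timestamp_float = mktime(v_timestamp.timetuple())

    def cap(requested: Optional[float]) -> float:
        # Same rule as ceiling_timestamp() in remap_public(): an omitted or
        # too-late timestamp collapses to the virtual table's cut-off.
        if requested is None or requested > v_timestamp_float:
            return v_timestamp_float
        return requested

    print(cap(None))    # omitted            -> cut-off
    print(cap(1.7e9))   # after the cut-off  -> cut-off
    print(cap(1.6e9))   # before the cut-off -> unchanged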
resp.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS" + resp.headers["Connection"] = "keep-alive" + return resp + + +## VALIDFRAGMENTS -------------------------------------------------------------- + + +def handle_valid_frags(table_id, node_id): + current_app.table_id = table_id + cg = app_utils.get_cg(table_id) + seg_ids = get_highest_child_nodes_with_meshes( + cg, np.uint64(node_id), stop_layer=1, verify_existence=True + ) + return app_utils.tobinary(seg_ids) + + +## MANIFEST -------------------------------------------------------------------- + + +def handle_get_manifest(table_id, node_id): + current_app.request_type = "manifest" + current_app.table_id = table_id + + data = {} + if len(request.data) > 0: + data = json.loads(request.data) + + bounding_box = None + if "bounds" in request.args: + bounds = request.args["bounds"] + bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T + + cg = app_utils.get_cg(table_id) + verify = request.args.get("verify", False) + verify = verify in ["True", "true", "1", True] + return_seg_ids = request.args.get("return_seg_ids", False) + prepend_seg_ids = request.args.get("prepend_seg_ids", False) + return_seg_ids = return_seg_ids in ["True", "true", "1", True] + prepend_seg_ids = prepend_seg_ids in ["True", "true", "1", True] + start_layer = cg.meta.custom_data.get("mesh", {}).get("max_layer", 2) + start_layer = int(request.args.get("start_layer", start_layer)) + if "start_layer" in data: + start_layer = int(data["start_layer"]) + + flexible_start_layer = None + if "flexible_start_layer" in data: + flexible_start_layer = int(data["flexible_start_layer"]) + args = ( + node_id, + verify, + return_seg_ids, + prepend_seg_ids, + start_layer, + flexible_start_layer, + bounding_box, + data, + ) + return manifest_response(cg, args) + + +def manifest_response(cg, args): + from pychunkedgraph.meshing.manifest import speculative_manifest_sharded + + ( + node_id, + verify, + return_seg_ids, + prepend_seg_ids, + start_layer, + flexible_start_layer, + bounding_box, + data, + ) = args + resp = {} + seg_ids = [] + if not verify: + seg_ids, resp["fragments"] = speculative_manifest_sharded( + cg, node_id, start_layer=start_layer, bounding_box=bounding_box + ) + + else: + seg_ids, resp["fragments"] = get_highest_child_nodes_with_meshes( + cg, + np.uint64(node_id), + start_layer=start_layer, + bounding_box=bounding_box, + ) + if prepend_seg_ids: + resp["fragments"] = [f"~{i}:{f}" for i, f in zip(seg_ids, resp["fragments"])] + if return_seg_ids: + resp["seg_ids"] = seg_ids + return _check_post_options(cg, resp, data, seg_ids) + + +def _check_post_options(cg, resp, data, seg_ids): + if app_utils.toboolean(data.get("return_seg_ids", "false")): + resp["seg_ids"] = seg_ids + if app_utils.toboolean(data.get("return_seg_id_layers", "false")): + resp["seg_id_layers"] = cg.get_chunk_layers(seg_ids) + if app_utils.toboolean(data.get("return_seg_chunk_coordinates", "false")): + resp["seg_chunk_coordinates"] = [ + cg.get_chunk_coordinates(seg_id) for seg_id in seg_ids + ] + return resp + + +## REMESHING ----------------------------------------------------- +def handle_remesh(table_id): + current_app.request_type = "remesh_enque" + current_app.table_id = table_id + is_priority = request.args.get("priority", True, type=str2bool) + is_redisjob = request.args.get("use_redis", False, type=str2bool) + + new_lvl2_ids = json.loads(request.data)["new_lvl2_ids"] + + if is_redisjob: + with Connection(redis.from_url(current_app.config["REDIS_URL"])): + + if 
is_priority: + retry = Retry(max=3, interval=[1, 10, 60]) + queue_name = "mesh-chunks" + else: + retry = Retry(max=3, interval=[60, 60, 60]) + queue_name = "mesh-chunks-low-priority" + q = Queue(queue_name, retry=retry, default_timeout=1200) + task = q.enqueue(meshing_tasks.remeshing, table_id, new_lvl2_ids) + + response_object = {"status": "success", "data": {"task_id": task.get_id()}} + + return jsonify(response_object), 202 + else: + new_lvl2_ids = np.array(new_lvl2_ids, dtype=np.uint64) + cg = app_utils.get_cg(table_id) + + if len(new_lvl2_ids) > 0: + t = threading.Thread( + target=_remeshing, args=(cg.get_serialized_info(), new_lvl2_ids) + ) + t.start() + + return Response(status=202) + + +def _remeshing(serialized_cg_info, lvl2_nodes): + cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) + cv_mesh_dir = cg.meta.dataset_info["mesh"] + cv_unsharded_mesh_dir = cg.meta.dataset_info["mesh_metadata"]["unsharded_mesh_dir"] + cv_unsharded_mesh_path = os.path.join( + cg.meta.data_source.WATERSHED, cv_mesh_dir, cv_unsharded_mesh_dir + ) + mesh_data = cg.meta.custom_data["mesh"] + + # TODO: stop_layer and mip should be configurable by dataset + meshgen.remeshing( + cg, + lvl2_nodes, + stop_layer=mesh_data["max_layer"], + mip=mesh_data["mip"], + max_err=mesh_data["max_error"], + cv_sharded_mesh_dir=cv_mesh_dir, + cv_unsharded_mesh_path=cv_unsharded_mesh_path, + ) + + return Response(status=200) + + +def clear_manifest_cache(cg, node_id): + node_ids = get_children_before_start_layer(cg, node_id, start_layer=2) + ManifestCache(cg.graph_id).clear_fragments(node_ids) diff --git a/pychunkedgraph/app/meshing/legacy/routes.py b/pychunkedgraph/app/meshing/legacy/routes.py new file mode 100644 index 000000000..a067b78b1 --- /dev/null +++ b/pychunkedgraph/app/meshing/legacy/routes.py @@ -0,0 +1,74 @@ +from flask import Blueprint +from middle_auth_client import auth_requires_permission, auth_required + +from pychunkedgraph.app import common as app_common +from pychunkedgraph.app.meshing import common +from pychunkedgraph.graph import exceptions as cg_exceptions +from pychunkedgraph.app.app_utils import remap_public + +bp = Blueprint( + "pcg_meshing_v0", __name__, url_prefix=f"/{common.__meshing_url_prefix__}/1.0" +) + +# ------------------------------- +# ------ Access control and index +# ------------------------------- + + +@bp.route("/") +@bp.route("/index") +@auth_required +def index(): + return common.index() + + +@bp.route +@auth_required +def home(): + return common.home() + + +# ------------------------------- +# ------ Measurements and Logging +# ------------------------------- + + +@bp.before_request +@auth_required +def before_request(): + return app_common.before_request() + + +@bp.after_request +def after_request(response): + return app_common.after_request(response) + + +@bp.errorhandler(Exception) +def unhandled_exception(e): + return app_common.unhandled_exception(e) + + +@bp.errorhandler(cg_exceptions.ChunkedGraphAPIError) +def api_exception(e): + return app_common.api_exception(e) + + +## VALIDFRAGMENTS -------------------------------------------------------------- + + +@bp.route("///validfragments", methods=["POST", "GET"]) +@remap_public +@auth_requires_permission("view") +def handle_valid_frags(table_id, node_id): + return common.handle_valid_frags(table_id, node_id) + + +## MANIFEST -------------------------------------------------------------------- + + +@bp.route("//manifest/:0", methods=["GET"]) +@auth_requires_permission("view") +@remap_public +def handle_get_manifest(table_id, 
node_id): + return common.handle_get_manifest(table_id, node_id) diff --git a/pychunkedgraph/app/meshing/tasks.py b/pychunkedgraph/app/meshing/tasks.py new file mode 100644 index 000000000..a1f11ca68 --- /dev/null +++ b/pychunkedgraph/app/meshing/tasks.py @@ -0,0 +1,27 @@ +from pychunkedgraph.app import app_utils +from pychunkedgraph.meshing import meshgen, meshgen_utils +import numpy as np +import os + + +def remeshing(table_id, lvl2_nodes): + lvl2_nodes = np.array(lvl2_nodes, dtype=np.uint64) + cg = app_utils.get_cg(table_id, skip_cache=True) + + cv_mesh_dir = cg.meta.dataset_info["mesh"] + cv_unsharded_mesh_dir = cg.meta.dataset_info["mesh_metadata"]["unsharded_mesh_dir"] + cv_unsharded_mesh_path = os.path.join( + cg.meta.data_source.WATERSHED, cv_mesh_dir, cv_unsharded_mesh_dir + ) + mesh_data = cg.meta.custom_data["mesh"] + + # TODO: stop_layer and mip should be configurable by dataset + meshgen.remeshing( + cg, + lvl2_nodes, + stop_layer=mesh_data["max_layer"], + mip=mesh_data["mip"], + max_err=mesh_data["max_error"], + cv_sharded_mesh_dir=cv_mesh_dir, + cv_unsharded_mesh_path=cv_unsharded_mesh_path, + ) \ No newline at end of file diff --git a/pychunkedgraph/app/meshing/v1/routes.py b/pychunkedgraph/app/meshing/v1/routes.py new file mode 100644 index 000000000..dda067e90 --- /dev/null +++ b/pychunkedgraph/app/meshing/v1/routes.py @@ -0,0 +1,100 @@ +# pylint: disable=invalid-name, missing-docstring + +from flask import Blueprint +from middle_auth_client import auth_requires_permission, auth_required + +from pychunkedgraph.app import common as app_common +from pychunkedgraph.app.meshing import common +from pychunkedgraph.graph import exceptions as cg_exceptions +from pychunkedgraph.app.app_utils import get_cg +from pychunkedgraph.app.app_utils import remap_public + + +bp = Blueprint( + "pcg_meshing_v1", __name__, url_prefix=f"/{common.__meshing_url_prefix__}/api/v1" +) + +# ------------------------------- +# ------ Access control and index +# ------------------------------- + + +@bp.route("/") +@bp.route("/index") +@auth_required +def index(): + return common.index() + + +@bp.route +@auth_required +def home(): + return common.home() + + +# ------------------------------- +# ------ Measurements and Logging +# ------------------------------- + + +@bp.before_request +# @auth_required +def before_request(): + return app_common.before_request() + + +@bp.after_request +# @auth_required +def after_request(response): + return app_common.after_request(response) + + +@bp.errorhandler(Exception) +def unhandled_exception(e): + return app_common.unhandled_exception(e) + + +@bp.errorhandler(cg_exceptions.ChunkedGraphAPIError) +def api_exception(e): + return app_common.api_exception(e) + + +## VALIDFRAGMENTS -------------------------------------------------------------- + + +@bp.route("/table//node//validfragments", methods=["GET"]) +@auth_requires_permission("view") +@remap_public +def handle_valid_frags(table_id, node_id): + return common.handle_valid_frags(table_id, node_id) + + +## MANIFEST -------------------------------------------------------------------- + + +@bp.route("/table//manifest/:0", methods=["GET"]) +@auth_requires_permission( + "view", + public_table_key="table_id", + public_node_key="node_id", +) +@remap_public +def handle_get_manifest(table_id, node_id): + return common.handle_get_manifest(table_id, node_id) + + +## ENQUE MESHING JOBS ---------------------------------------------------------- + + +@bp.route("/table//remeshing", methods=["POST"]) +@auth_requires_permission("edit") 
+@remap_public(edit=True) +def handle_remesh(table_id): + return common.handle_remesh(table_id) + + +@bp.route("/table//clear_manifest_cache/", methods=["POST"]) +@auth_requires_permission("admin") +def handle_clear_manifest_cache(table_id, node_id): + cg = get_cg(table_id) + common.clear_manifest_cache(cg, node_id) diff --git a/pychunkedgraph/app/meshing_app_blueprint.py b/pychunkedgraph/app/meshing_app_blueprint.py deleted file mode 100644 index 0cabcdc32..000000000 --- a/pychunkedgraph/app/meshing_app_blueprint.py +++ /dev/null @@ -1,119 +0,0 @@ -from flask import Blueprint, request, make_response, jsonify, Response,\ - redirect, current_app -import json -import numpy as np - - -from pychunkedgraph.meshing import meshgen_utils, meshgen -from pychunkedgraph.app import app_utils -from pychunkedgraph.backend import chunkedgraph - -__version__ = 'fafb.1.21' -bp = Blueprint('pychunkedgraph_meshing', __name__, url_prefix="/meshing") - -# ------------------------------- -# ------ Access control and index -# ------------------------------- - -@bp.route('/') -@bp.route("/index") -def index(): - return "Meshing Server -- " + __version__ - - -@bp.route -def home(): - resp = make_response() - resp.headers['Access-Control-Allow-Origin'] = '*' - acah = "Origin, X-Requested-With, Content-Type, Accept" - resp.headers["Access-Control-Allow-Headers"] = acah - resp.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS" - resp.headers["Connection"] = "keep-alive" - return resp - - -def _remeshing(serialized_cg_info, lvl2_nodes): - cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) - - # TODO: stop_layer and mip should be configurable by dataset - meshgen.remeshing(cg, lvl2_nodes, stop_layer=4, cv_path=None, - cv_mesh_dir=None, mip=1, max_err=40) - - return Response(status=200) - - -# ------------------------------------------------------------------------------ - -@bp.route('/1.0///mesh_preview', methods=['POST', 'GET']) -def handle_preview_meshes(table_id, node_id): - if len(request.data) > 0: - data = json.loads(request.data) - else: - data = {} - - node_id = np.uint64(node_id) - - cg = app_utils.get_cg(table_id) - - if "seg_ids" in data: - seg_ids = data["seg_ids"] - - chunk_id = cg.get_chunk_id(node_id) - supervoxel_ids = [cg.get_node_id(seg_id, chunk_id) - for seg_id in seg_ids] - else: - supervoxel_ids = None - - meshgen.mesh_lvl2_preview(cg, node_id, supervoxel_ids=supervoxel_ids, - cv_path=None, cv_mesh_dir=None, mip=2, - simplification_factor=999999, - max_err=40, parallel_download=1, verbose=True, - cache_control='no-cache') - return Response(status=200) - - -## VALIDFRAGMENTS -------------------------------------------------------------- - -@bp.route('/1.0///validfragments', methods=['POST', 'GET']) -def handle_valid_frags(table_id, node_id): - cg = app_utils.get_cg(table_id) - - seg_ids = meshgen_utils.get_highest_child_nodes_with_meshes( - cg, np.uint64(node_id), stop_layer=1, verify_existence=True) - - return app_utils.tobinary(seg_ids) - - -## MANIFEST -------------------------------------------------------------------- - -@bp.route('/1.0//manifest/:0', methods=['GET']) -def handle_get_manifest(table_id, node_id): - if len(request.data) > 0: - data = json.loads(request.data) - else: - data = {} - - if "start_layer" in data: - start_layer = int(data["start_layer"]) - else: - start_layer = None - - if "bounds" in request.args: - bounds = request.args["bounds"] - bounding_box = np.array([b.split("-") for b in bounds.split("_")], - dtype=np.int).T - else: - bounding_box = None - - 
verify = request.args.get('verify', False) - verify = verify in ['True', 'true', '1', True] - - cg = app_utils.get_cg(table_id) - - seg_ids = meshgen_utils.get_highest_child_nodes_with_meshes( - cg, np.uint64(node_id), stop_layer=2, start_layer=start_layer, - bounding_box=bounding_box, verify_existence=verify) - - filenames = [meshgen_utils.get_mesh_name(cg, s) for s in seg_ids] - - return jsonify(fragments=filenames) diff --git a/pychunkedgraph/backend/__init__.py b/pychunkedgraph/app/segmentation/__init__.py similarity index 100% rename from pychunkedgraph/backend/__init__.py rename to pychunkedgraph/app/segmentation/__init__.py diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py new file mode 100644 index 000000000..9690d5c3d --- /dev/null +++ b/pychunkedgraph/app/segmentation/common.py @@ -0,0 +1,1279 @@ +# pylint: disable=invalid-name, missing-docstring + +import json +import os +import time +from datetime import datetime +from functools import reduce +from collections import deque, defaultdict + +import numpy as np +import pandas as pd +from flask import current_app, g, jsonify, make_response, request +from pytz import UTC + +from pychunkedgraph import __version__ +from pychunkedgraph.app import app_utils +from pychunkedgraph.graph import ( + attributes, + cutting, + segmenthistory, +) +from pychunkedgraph.graph import ( + edges as cg_edges, +) +from pychunkedgraph.graph import ( + exceptions as cg_exceptions, +) +from pychunkedgraph.graph.analysis import pathing +from pychunkedgraph.graph.attributes import OperationLogs +from pychunkedgraph.graph.misc import get_contact_sites +from pychunkedgraph.graph.operation import GraphEditOperation +from pychunkedgraph.graph.utils import basetypes +from pychunkedgraph.meshing import mesh_analysis + +__api_versions__ = [0, 1] +__segmentation_url_prefix__ = os.environ.get("SEGMENTATION_URL_PREFIX", "segmentation") + + +def index(): + return f"PyChunkedGraph Segmentation v{__version__}" + + +def home(): + resp = make_response() + resp.headers["Access-Control-Allow-Origin"] = "*" + acah = "Origin, X-Requested-With, Content-Type, Accept" + resp.headers["Access-Control-Allow-Headers"] = acah + resp.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS" + resp.headers["Connection"] = "keep-alive" + return resp + + +def _parse_timestamp( + arg_name, default_timestamp=0, return_datetime=False, allow_none=False +): + """Convert seconds since epoch to UTC datetime.""" + timestamp = request.args.get(arg_name, default_timestamp) + if timestamp is None: + if allow_none: + return None + else: + raise ( + cg_exceptions.BadRequest(f"Timestamp parameter {arg_name} is mandatory") + ) + try: + timestamp = float(timestamp) + if return_datetime: + return datetime.fromtimestamp(timestamp, UTC) + else: + return timestamp + except (TypeError, ValueError): + raise ( + cg_exceptions.BadRequest( + f"Timestamp parameter {arg_name} is not a valid unix timestamp" + ) + ) + + +def _get_bounds_from_request(request): + if "bounds" in request.args: + bounds = request.args["bounds"] + bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T + else: + bounding_box = None + return bounding_box + + +# ------------------- +# ------ Applications +# ------------------- + + +def sleep_me(sleep): + current_app.request_type = "sleep" + + time.sleep(sleep) + return "zzz... {} ... 
awake".format(sleep) + + +def handle_info(table_id): + cg = app_utils.get_cg(table_id) + dataset_info = cg.meta.dataset_info + app_info = {"app": {"supported_api_versions": list(__api_versions__)}} + combined_info = {**dataset_info, **app_info} + combined_info["sharded_mesh"] = True + combined_info["verify_mesh"] = cg.meta.custom_data.get("mesh", {}).get( + "verify", False + ) + mesh_dir = cg.meta.custom_data.get("mesh", {}).get("dir", None) + if mesh_dir is not None: + combined_info["mesh_dir"] = mesh_dir + elif combined_info.get("mesh_dir", None) is not None: + combined_info["mesh_dir"] = "graphene_meshes" + return jsonify(combined_info) + + +def handle_api_versions(): + return jsonify(__api_versions__) + + +def handle_version(): + return jsonify(__version__) + + +### GET ROOT ------------------------------------------------------------------- + + +def handle_root(table_id, atomic_id): + current_app.table_id = table_id + + # Convert seconds since epoch to UTC datetime + timestamp = _parse_timestamp("timestamp", time.time(), return_datetime=True) + + stop_layer = request.args.get("stop_layer", None) + if stop_layer is not None: + try: + stop_layer = int(stop_layer) + except (TypeError, ValueError) as e: + raise (cg_exceptions.BadRequest(f"stop_layer is not an integer {e}")) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + root_id = cg.get_root( + np.uint64(atomic_id), stop_layer=stop_layer, time_stamp=timestamp + ) + + # Return root ID + return root_id + + +### GET MINIMAL COVERING NODES -------------------------------------------------- + + +def handle_find_minimal_covering_nodes(table_id, is_binary=True): + if is_binary: + node_ids = np.frombuffer(request.data, np.uint64) + else: + node_ids = np.array(json.loads(request.data)["node_ids"], dtype=np.uint64) + + # Input parameters + timestamp = _parse_timestamp("timestamp", time.time(), return_datetime=True) + + # Initialize data structures + node_queue = defaultdict(set) + download_list = defaultdict(set) + + # Get initial layers for the provided node_ids + cg = app_utils.get_cg(table_id) + initial_layers = np.array([cg.get_chunk_layer(node_id) for node_id in node_ids]) + + # Populate node_queue with nodes grouped by their layers + for node_id, layer in zip(node_ids, initial_layers): + node_queue[layer].add(node_id) + + # find the minimum layer for the node_ids + min_layer = np.min(initial_layers) + min_children = cg.get_subgraph_nodes( + node_ids, return_layers=[min_layer], serializable=False, return_flattened=True + ) + # concatenate all the min_children together to one list from the dictionary + min_children = np.concatenate( + [min_children[node_id] for node_id in min_children.keys()] + ) + + # Process nodes from their layers + + for layer in range( + min_layer, cg.meta.layer_count + ): # Process from higher layers to lower layers + if len(node_queue[layer]) == 0: + continue + + current_nodes = list(node_queue[layer]) + + # Call handle_roots to find parents + parents = cg.get_roots( + current_nodes, stop_layer=layer + 1, time_stamp=timestamp + ) + unique_parents = np.unique(parents) + parent_layers = np.array( + [cg.get_chunk_layer(parent) for parent in unique_parents] + ) + + # Call handle_leaves_many to get leaves + leaves = cg.get_subgraph_nodes( + unique_parents, + return_layers=[min_layer], + serializable=False, + return_flattened=True, + ) + + # Process parents + for parent, parent_layer in zip(unique_parents, parent_layers): + child_mask = np.isin(leaves[parent], min_children) + if not np.all(child_mask): + # Call 
handle_children to fetch children + children = cg.get_children(parent) + + child_layers = np.array( + [cg.get_chunk_layer(child) for child in children] + ) + for child, child_layer in zip(children, child_layers): + if child in node_queue[child_layer]: + download_list[child_layer].add(child) + else: + node_queue[parent_layer].add(parent) + + # Clear the current layer's queue after processing + node_queue[layer].clear() + + # Return the download list + download_list = np.concatenate([np.array(list(v)) for v in download_list.values()]) + + return download_list + + +### GET ROOTS ------------------------------------------------------------------- + + +def handle_roots(table_id, is_binary=False): + current_app.request_type = "roots" + current_app.table_id = table_id + + if is_binary: + node_ids = np.frombuffer(request.data, np.uint64) + else: + node_ids = np.array(json.loads(request.data)["node_ids"], dtype=np.uint64) + # Convert seconds since epoch to UTC datetime + timestamp = _parse_timestamp("timestamp", time.time(), return_datetime=True) + + cg = app_utils.get_cg(table_id) + stop_layer = int(request.args.get("stop_layer", cg.meta.layer_count)) + is_root_layer = stop_layer == cg.meta.layer_count + assert_roots = bool(request.args.get("assert_roots", False)) + fail_to_zero = bool(request.args.get("fail_to_zero", True)) + root_ids = cg.get_roots( + node_ids, + stop_layer=stop_layer, + time_stamp=timestamp, + assert_roots=assert_roots and is_root_layer, + fail_to_zero=fail_to_zero, + ) + + return root_ids + + +### RANGE READ ------------------------------------------------------------------- + + +def handle_l2_chunk_children(table_id, chunk_id, as_array): + current_app.request_type = "l2_chunk_children" + current_app.table_id = table_id + + # Convert seconds since epoch to UTC datetime + timestamp = _parse_timestamp("timestamp", time.time(), return_datetime=True) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + + chunk_layer = cg.get_chunk_layer(chunk_id) + if chunk_layer != 2: + raise ( + cg_exceptions.PreconditionError( + f"This function only accepts level 2 chunks, the chunk requested is a level {chunk_layer} chunk" + ) + ) + + rr_chunk = cg.range_read_chunk( + chunk_id=np.uint64(chunk_id), + properties=attributes.Hierarchy.Child, + time_stamp=timestamp, + ) + + if as_array: + l2_chunk_array = [] + + for l2 in rr_chunk: + svs = rr_chunk[l2][0].value + for sv in svs: + l2_chunk_array.extend([l2, sv]) + + return np.array(l2_chunk_array) + else: + # store in dict of keys to arrays to remove reliance on bigtable + l2_chunk_dict = {} + for k in rr_chunk: + l2_chunk_dict[k] = rr_chunk[k][0].value + + return l2_chunk_dict + + +def str2bool(v): + return v.lower() in ("yes", "true", "t", "1") + + +def publish_edit( + table_id: str, + user_id: str, + result: GraphEditOperation.Result, + is_priority=True, + remesh: bool = True, +): + import pickle + + from messagingclient import MessagingClient + + attributes = { + "table_id": table_id, + "user_id": user_id, + "remesh_priority": "true" if is_priority else "false", + "remesh": "true" if remesh else "false", + } + payload = { + "operation_id": int(result.operation_id), + "new_lvl2_ids": result.new_lvl2_ids.tolist(), + "new_root_ids": result.new_root_ids.tolist(), + } + + exchange = os.getenv("PYCHUNKEDGRAPH_EDITS_EXCHANGE", "pychunkedgraph") + c = MessagingClient() + c.publish(exchange, pickle.dumps(payload), attributes) + + +### MERGE ---------------------------------------------------------------------- + + +def handle_merge(table_id, 
allow_same_segment_merge=False): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + nodes = json.loads(request.data) + is_priority = request.args.get("priority", True, type=str2bool) + remesh = request.args.get("remesh", True, type=str2bool) + chebyshev_distance = request.args.get("chebyshev_distance", 3, type=int) + + current_app.logger.debug(nodes) + assert len(nodes) == 2 + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id, skip_cache=True) + node_ids = [] + coords = [] + for node in nodes: + node_ids.append(node[0]) + coords.append(np.array(node[1:]) / cg.segmentation_resolution) + + atomic_edge = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids) + # Protection from long range mergers + chunk_coord_delta = cg.get_chunk_coordinates( + atomic_edge[0] + ) - cg.get_chunk_coordinates(atomic_edge[1]) + + if np.any(np.abs(chunk_coord_delta) > chebyshev_distance): + raise cg_exceptions.BadRequest( + "Chebyshev distance between merge points exceeded allowed maximum " + "(3 chunks)." + ) + + try: + ret = cg.add_edges( + user_id=user_id, + atomic_edges=np.array(atomic_edge, dtype=np.uint64), + source_coords=coords[:1], + sink_coords=coords[1:], + allow_same_segment_merge=allow_same_segment_merge, + ) + + except cg_exceptions.LockingError as e: + raise cg_exceptions.InternalServerError(e) + except cg_exceptions.PreconditionError as e: + raise cg_exceptions.BadRequest(str(e)) + + current_app.operation_id = ret.operation_id + if ret.new_root_ids is None: + raise cg_exceptions.InternalServerError( + "Could not merge selected " "supervoxel." + ) + + current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids)) + + if len(ret.new_lvl2_ids) > 0: + publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh) + + return ret + + +### SPLIT ---------------------------------------------------------------------- + + +def handle_split(table_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + data = json.loads(request.data) + is_priority = request.args.get("priority", True, type=str2bool) + remesh = request.args.get("remesh", True, type=str2bool) + mincut = request.args.get("mincut", True, type=str2bool) + + current_app.logger.debug(data) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id, skip_cache=True) + node_idents = [] + node_ident_map = { + "sources": 0, + "sinks": 1, + } + coords = [] + node_ids = [] + + for k in ["sources", "sinks"]: + for node in data[k]: + node_ids.append(node[0]) + coords.append(np.array(node[1:]) / cg.segmentation_resolution) + node_idents.append(node_ident_map[k]) + + node_ids = np.array(node_ids, dtype=np.uint64) + coords = np.array(coords) + node_idents = np.array(node_idents) + sv_ids = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids) + current_app.logger.debug( + {"node_id": node_ids, "sv_id": sv_ids, "node_ident": node_idents} + ) + + try: + ret = cg.remove_edges( + user_id=user_id, + source_ids=sv_ids[node_idents == 0], + sink_ids=sv_ids[node_idents == 1], + source_coords=coords[node_idents == 0], + sink_coords=coords[node_idents == 1], + mincut=mincut, + ) + except cg_exceptions.LockingError as e: + raise cg_exceptions.InternalServerError(e) + except cg_exceptions.PreconditionError as e: + raise cg_exceptions.BadRequest(str(e)) + + current_app.operation_id = ret.operation_id + if ret.new_root_ids is None: + raise cg_exceptions.InternalServerError( + "Could not split selected segment groups." 
+ ) + + current_app.logger.debug(("after split:", ret.new_root_ids)) + current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids)) + + if len(ret.new_lvl2_ids) > 0: + publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh) + + return ret + + +### UNDO ---------------------------------------------------------------------- + + +def handle_undo(table_id): + current_app.table_id = table_id + + data = json.loads(request.data) + is_priority = request.args.get("priority", True, type=str2bool) + remesh = request.args.get("remesh", True, type=str2bool) + user_id = str(g.auth_user.get("id", current_app.user_id)) + + current_app.logger.debug(data) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + operation_id = np.uint64(data["operation_id"]) + + try: + ret = cg.undo_operation(user_id=user_id, operation_id=operation_id) + except cg_exceptions.LockingError as e: + raise cg_exceptions.InternalServerError(e) + except (cg_exceptions.PreconditionError, cg_exceptions.PostconditionError) as e: + raise cg_exceptions.BadRequest(str(e)) + + current_app.logger.debug(("after undo:", ret.new_root_ids)) + current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids)) + + if ret.new_lvl2_ids.size > 0: + publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh) + + return ret + + +### REDO ---------------------------------------------------------------------- + + +def handle_redo(table_id): + current_app.table_id = table_id + + data = json.loads(request.data) + is_priority = request.args.get("priority", True, type=str2bool) + remesh = request.args.get("remesh", True, type=str2bool) + user_id = str(g.auth_user.get("id", current_app.user_id)) + + current_app.logger.debug(data) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + operation_id = np.uint64(data["operation_id"]) + + try: + ret = cg.redo_operation(user_id=user_id, operation_id=operation_id) + except cg_exceptions.LockingError as e: + raise cg_exceptions.InternalServerError(e) + except (cg_exceptions.PreconditionError, cg_exceptions.PostconditionError) as e: + raise cg_exceptions.BadRequest(str(e)) + + current_app.logger.debug(("after redo:", ret.new_root_ids)) + current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids)) + + if ret.new_lvl2_ids.size > 0: + publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh) + + return ret + + +### ROLLBACK USER -------------------------------------------------------------- + + +def handle_rollback(table_id): + current_app.table_id = table_id + + user_id = str(g.auth_user.get("id", current_app.user_id)) + target_user_id = request.args["user_id"] + + is_priority = request.args.get("priority", True, type=str2bool) + remesh = request.args.get("remesh", True, type=str2bool) + skip_operation_ids = np.array( + json.loads(request.args.get("skip_operation_ids", "[]")), dtype=np.uint64 + ) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + user_operations = all_user_operations(table_id) + operation_ids = user_operations["operation_id"] + timestamps = user_operations["timestamp"] + operations = list(zip(operation_ids, timestamps)) + operations.sort(key=lambda op: op[1], reverse=True) + + for operation in operations: + operation_id = operation[0] + if operation_id in skip_operation_ids: + continue + try: + ret = cg.undo_operation(user_id=target_user_id, operation_id=operation_id) + except cg_exceptions.LockingError: + raise cg_exceptions.InternalServerError( + "Could not acquire root lock for undo operation." 
+ ) + except (cg_exceptions.PreconditionError, cg_exceptions.PostconditionError) as e: + raise cg_exceptions.BadRequest(str(e)) + + if ret.new_lvl2_ids.size > 0: + publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh) + + return user_operations + + +### USER OPERATIONS ------------------------------------------------------------- + + +def all_user_operations( + table_id, include_undone=False, include_partial_splits=True, include_errored=True +): + # Gets all operations by the user. + # If include_undone is false, it filters to operations that are not undone. + # If the operation has been undone by anyone, it won't be returned here, + # unless it has been redone by anyone (and hasn't been undone again, etc.). + # The original user is considered to have "ownership" of the original edit, + # and that does not change even if someone else undoes/redoes that edit later. + # If include_partial_splits is false, it will not include splits that result + # in a single root ID (and so had no effect). + # If include_errored is false, it will not include operations that failed with + # an error. + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + target_user_id = request.args.get("user_id", None) + + start_time = _parse_timestamp("start_time", 0, return_datetime=True) + end_time = _parse_timestamp("end_time", datetime.utcnow(), return_datetime=True) + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + + log_rows = cg.client.read_log_entries( + start_time=start_time, end_time=end_time, user_id=target_user_id + ) + + valid_entry_ids = [] + timestamp_list = [] + undone_ids = np.array([]) + + entry_ids = np.sort(list(log_rows.keys())) + for entry_id in entry_ids: + entry = log_rows[entry_id] + user_id = entry[OperationLogs.UserID] + + should_check = ( + OperationLogs.Status not in entry + or entry[OperationLogs.Status] == OperationLogs.StatusCodes.SUCCESS.value + ) + + split_valid = ( + include_partial_splits + or (OperationLogs.AddedEdge in entry) + or (OperationLogs.RootID not in entry) + or (len(entry[OperationLogs.RootID]) > 1) + ) + if not split_valid: + print("excluding partial split", entry_id) + error_valid = include_errored or should_check + if not error_valid: + print("excluding errored", entry_id) + if user_id == target_user_id and split_valid and error_valid: + valid_entry_ids.append(entry_id) + timestamp = entry["timestamp"] + timestamp_list.append(timestamp) + + if should_check: + # if it is an undo of another operation, mark it as undone + if OperationLogs.UndoOperationID in entry: + undone_id = entry[OperationLogs.UndoOperationID] + undone_ids = np.append(undone_ids, undone_id) + + # if it is a redo of another operation, unmark it as undone + if OperationLogs.RedoOperationID in entry: + redone_id = entry[OperationLogs.RedoOperationID] + undone_ids = np.delete(undone_ids, np.argwhere(undone_ids == redone_id)) + + if include_undone: + return {"operation_id": valid_entry_ids, "timestamp": timestamp_list} + + filtered_entry_ids = [] + filtered_timestamp_list = [] + for i in range(len(valid_entry_ids)): + entry_id = valid_entry_ids[i] + entry = log_rows[entry_id] + + if ( + OperationLogs.UndoOperationID in entry + or OperationLogs.RedoOperationID in entry + ): + continue + + undone = entry_id in undone_ids + if not undone: + filtered_entry_ids.append(entry_id) + timestamp = entry["timestamp"] + filtered_timestamp_list.append(timestamp) + + return {"operation_id": filtered_entry_ids, "timestamp": filtered_timestamp_list} + + 
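+# NOTE: all_user_operations returns two parallel lists, "operation_id" and
+# "timestamp". With include_undone=False an edit is listed only while it is in
+# effect: an undo drops it from the result and a later redo restores it.
+
+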
+### CHILDREN ------------------------------------------------------------------- + + +def handle_children(table_id, parent_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + cg = app_utils.get_cg(table_id) + + parent_id = np.uint64(parent_id) + layer = cg.get_chunk_layer(parent_id) + + if layer > 1: + children = cg.get_children(parent_id) + else: + children = np.array([]) + + return children + + +### LEAVES --------------------------------------------------------------------- + + +def handle_leaves(table_id, root_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + stop_layer = int(request.args.get("stop_layer", 1)) + + bounding_box = _get_bounds_from_request(request) + + cg = app_utils.get_cg(table_id) + if stop_layer > 1: + subgraph = cg.get_subgraph_nodes( + int(root_id), + bbox=bounding_box, + bbox_is_coordinate=True, + return_layers=[stop_layer], + return_flattened=True, + ) + + return subgraph + return cg.get_subgraph_leaves( + int(root_id), + bbox=bounding_box, + bbox_is_coordinate=True, + ) + + +### LEAVES OF MANY ROOTS --------------------------------------------------------------------- + + +def handle_leaves_many(table_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + bounding_box = _get_bounds_from_request(request) + + node_ids = np.array(json.loads(request.data)["node_ids"], dtype=np.uint64) + stop_layer = int(request.args.get("stop_layer", 1)) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + + node_to_leaves_mapping = cg.get_subgraph_nodes( + node_ids, + bbox=bounding_box, + bbox_is_coordinate=True, + return_layers=[stop_layer], + serializable=True, + return_flattened=True, + ) + + return node_to_leaves_mapping + + +### LEAVES FROM LEAVES --------------------------------------------------------- + + +def handle_leaves_from_leave(table_id, atomic_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + bounding_box = _get_bounds_from_request(request) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + root_id = cg.get_root(int(atomic_id)) + + atomic_ids = cg.get_subgraph( + root_id, bbox=bounding_box, bbox_is_coordinate=True, nodes_only=True + ) + + return np.concatenate([np.array([root_id]), atomic_ids]) + + +### SUBGRAPH ------------------------------------------------------------------- + + +def handle_subgraph(table_id, root_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + bounding_box = _get_bounds_from_request(request) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + l2id_agglomeration_d, edges = cg.get_subgraph( + int(root_id), + bbox=bounding_box, + bbox_is_coordinate=True, + ) + edges = reduce(lambda x, y: x + y, edges, cg_edges.Edges([], [])) + supervoxels = np.concatenate( + [agg.supervoxels for agg in l2id_agglomeration_d.values()] + ) + mask0 = np.in1d(edges.node_ids1, supervoxels) + mask1 = np.in1d(edges.node_ids2, supervoxels) + edges = edges[mask0 & mask1] + + return edges + + +### CHANGE LOG ----------------------------------------------------------------- + + +def change_log(table_id, root_id=None, filtered=False): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + time_stamp_past = _parse_timestamp("timestamp", 0, return_datetime=True) + + cg = app_utils.get_cg(table_id) + if not root_id: + return 
segmenthistory.get_all_log_entries(cg) + history = segmenthistory.SegmentHistory( + cg, [int(root_id)], timestamp_past=time_stamp_past + ) + return history.change_log_summary(filtered=filtered) + + +def tabular_change_log_recent(table_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + start_time = _parse_timestamp("timestamp", 0, return_datetime=True) + end_time = ( + None + if request.args.get("timestamp_end", None) is None + else _parse_timestamp("timestamp_end", return_datetime=True) + ) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + + log_rows = cg.client.read_log_entries(start_time=start_time, end_time=end_time) + + timestamp_list = [] + user_list = [] + is_merge_list = [] + + operation_ids = np.sort(list(log_rows.keys())) + for operation_id in operation_ids: + operation = log_rows[operation_id] + + timestamp = operation["timestamp"] + timestamp_list.append(timestamp) + + user_id = operation[attributes.OperationLogs.UserID] + user_list.append(user_id) + + is_merge = attributes.OperationLogs.AddedEdge in operation + is_merge_list.append(is_merge) + + return pd.DataFrame.from_dict( + { + "operation_id": operation_ids, + "timestamp": timestamp_list, + "user_id": user_list, + "is_merge": is_merge_list, + } + ) + + +def tabular_change_logs(table_id, root_ids, filtered=False): + current_app.request_type = "tabular_changelog_many" + + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + history = segmenthistory.SegmentHistory( + cg, + root_ids, + ) + if filtered: + tab = history.tabular_changelogs_filtered + else: + tab = history.tabular_changelogs + + all_user_ids = [] + for tab_k in tab.keys(): + all_user_ids.extend(np.array(tab[tab_k]["user_id"]).reshape(-1)) + + all_user_ids = [] + for tab_k in tab.keys(): + all_user_ids.extend(np.array(tab[tab_k]["user_id"]).reshape(-1)) + + all_user_ids = np.unique(all_user_ids) + + if len(all_user_ids) == 0: + return tab + + user_name_dict, user_aff_dict = app_utils.get_userinfo_dict( + all_user_ids, current_app.config["AUTH_TOKEN"] + ) + + for tab_k in tab.keys(): + user_names = [ + user_name_dict.get(int(id_), "unknown") + for id_ in np.array(tab[tab_k]["user_id"]) + ] + user_affs = [ + user_aff_dict.get(int(id_), "unknown") + for id_ in np.array(tab[tab_k]["user_id"]) + ] + tab[tab_k]["user_name"] = user_names + tab[tab_k]["user_affiliation"] = user_affs + return tab + + +def merge_log(table_id, root_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + hist = segmenthistory.SegmentHistory(cg, int(root_id)) + return hist.merge_log(correct_for_wrong_coord_type=False) + + +def handle_lineage_graph(table_id, root_id=None): + from networkx import node_link_data + + from pychunkedgraph.graph.lineage import lineage_graph + + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + timestamp_past = _parse_timestamp("timestamp_past", 0, return_datetime=True) + timestamp_future = _parse_timestamp( + "timestamp_future", time.time(), return_datetime=True + ) + + cg = app_utils.get_cg(table_id) + if root_id is None: + root_ids = np.array(json.loads(request.data)["root_ids"], dtype=np.uint64) + graph = lineage_graph(cg, root_ids, timestamp_past, timestamp_future) + return node_link_data(graph) + history_ids = segmenthistory.SegmentHistory( + cg, int(root_id), 
timestamp_past, timestamp_future + ) + return node_link_data(history_ids.lineage_graph) + + +def handle_past_id_mapping(table_id): + root_ids = np.array(json.loads(request.data)["root_ids"], dtype=np.uint64) + timestamp_past = _parse_timestamp( + "timestamp_past", default_timestamp=0, return_datetime=True + ) + timestamp_future = _parse_timestamp( + "timestamp_future", default_timestamp=time.time(), return_datetime=True + ) + + cg = app_utils.get_cg(table_id) + hist = segmenthistory.SegmentHistory( + cg, root_ids, timestamp_past=timestamp_past, timestamp_future=timestamp_future + ) + past_id_mapping, future_id_mapping = hist.past_future_id_mapping() + return { + "past_id_map": {str(k): past_id_mapping[k] for k in past_id_mapping.keys()}, + "future_id_map": { + str(k): future_id_mapping[k] for k in future_id_mapping.keys() + }, + } + + +def last_edit(table_id, root_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + cg = app_utils.get_cg(table_id) + hist = segmenthistory.SegmentHistory(cg, int(root_id)) + return hist.last_edit_timestamp(int(root_id)) + + +def oldest_timestamp(table_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + cg = app_utils.get_cg(table_id) + return cg.get_earliest_timestamp() + + +### CONTACT SITES -------------------------------------------------------------- + + +def handle_contact_sites(table_id, root_id): + partners = request.args.get("partners", True, type=app_utils.toboolean) + + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + timestamp = _parse_timestamp("timestamp", time.time(), return_datetime=True) + + bounding_box = _get_bounds_from_request(request) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + + cs_list, cs_metadata = get_contact_sites( + cg, + np.uint64(root_id), + bounding_box=bounding_box, + compute_partner=partners, + time_stamp=timestamp, + ) + + return cs_list, cs_metadata + + +def handle_pairwise_contact_sites(table_id, first_node_id, second_node_id): + current_app.request_type = "pairwise_contact_sites" + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + timestamp = _parse_timestamp("timestamp", time.time(), return_datetime=True) + + exact_location = request.args.get("exact_location", True, type=app_utils.toboolean) + cg = app_utils.get_cg(table_id) + contact_sites_list, cs_metadata = contact_sites.get_contact_sites_pairwise( + cg, + np.uint64(first_node_id), + np.uint64(second_node_id), + end_time=timestamp, + exact_location=exact_location, + ) + return contact_sites_list, cs_metadata + + +### SPLIT PREVIEW -------------------------------------------------------------- + + +def handle_split_preview(table_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + data = json.loads(request.data) + current_app.logger.debug(data) + + cg = app_utils.get_cg(table_id) + node_idents = [] + node_ident_map = { + "sources": 0, + "sinks": 1, + } + coords = [] + node_ids = [] + + for k in ["sources", "sinks"]: + for node in data[k]: + node_ids.append(node[0]) + coords.append(np.array(node[1:]) / cg.segmentation_resolution) + node_idents.append(node_ident_map[k]) + + node_ids = np.array(node_ids, dtype=np.uint64) + coords = np.array(coords) + node_idents = np.array(node_idents) + sv_ids = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids) + current_app.logger.debug( + {"node_id": node_ids, "sv_id": 
sv_ids, "node_ident": node_idents} + ) + + try: + supervoxel_ccs, illegal_split = cutting.run_split_preview( + cg=cg, + source_ids=sv_ids[node_idents == 0], + sink_ids=sv_ids[node_idents == 1], + source_coords=coords[node_idents == 0], + sink_coords=coords[node_idents == 1], + bb_offset=(240, 240, 24), + ) + except cg_exceptions.PreconditionError as e: + raise cg_exceptions.BadRequest(str(e)) + + resp = { + "supervoxel_connected_components": supervoxel_ccs, + "illegal_split": illegal_split, + } + return resp + + +### FIND PATH -------------------------------------------------------------- + + +def handle_find_path(table_id, precision_mode): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + nodes = json.loads(request.data) + current_app.logger.debug(nodes) + assert len(nodes) == 2 + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + node_ids = [] + coords = [] + for node in nodes: + node_ids.append(node[0]) + coords.append(np.array(node[1:]) / cg.segmentation_resolution) + + if len(coords) != 2: + cg_exceptions.BadRequest("Merge needs two nodes.") + source_supervoxel_id, target_supervoxel_id = app_utils.handle_supervoxel_id_lookup( + cg, coords, node_ids + ) + + source_l2_id = cg.get_parent(source_supervoxel_id) + target_l2_id = cg.get_parent(target_supervoxel_id) + + print("Finding path...") + print(f"Source: {source_supervoxel_id}") + print(f"Target: {target_supervoxel_id}") + + root_time_stamp = cg.get_node_timestamps( + [np.uint64(nodes[0][0])], return_numpy=False + )[0] + l2_path = pathing.find_l2_shortest_path( + cg, source_l2_id, target_l2_id, time_stamp=root_time_stamp + ) + print(f"Path: {l2_path}") + if precision_mode: + centroids, failed_l2_ids = mesh_analysis.compute_mesh_centroids_of_l2_ids( + cg, l2_path, flatten=True + ) + print(f"Centroids: {centroids}") + print(f"Failed L2 ids: {failed_l2_ids}") + return { + "centroids_list": centroids, + "failed_l2_ids": failed_l2_ids, + "l2_path": l2_path, + } + else: + centroids = pathing.compute_rough_coordinate_path(cg, l2_path) + print(f"Centroids: {centroids}") + return {"centroids_list": centroids, "failed_l2_ids": [], "l2_path": l2_path} + + +### GET_LAYER2_SUBGRAPH +def handle_get_layer2_graph(table_id, node_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + bounding_box = _get_bounds_from_request(request) + + cg = app_utils.get_cg(table_id) + print("Finding edge graph...") + edge_graph = pathing.get_lvl2_edge_list(cg, int(node_id), bbox=bounding_box) + print("Edge graph found len: {}".format(len(edge_graph))) + return {"edge_graph": edge_graph} + + +### ROOT INFO ---------------------------------------------------------------- + + +def handle_is_latest_roots(table_id, is_binary): + current_app.request_type = "is_latest_roots" + current_app.table_id = table_id + + if is_binary: + node_ids = np.frombuffer(request.data, np.uint64) + else: + node_ids = np.array(json.loads(request.data)["node_ids"], dtype=np.uint64) + # Convert seconds since epoch to UTC datetime + timestamp = _parse_timestamp("timestamp", time.time(), return_datetime=True) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + + return cg.is_latest_roots(node_ids, time_stamp=timestamp) + + +def _handle_latest(cg, node_ids, timestamp): + latest_mask = cg.is_latest_roots(node_ids, time_stamp=timestamp) + non_latest_ids = node_ids[~latest_mask] + row_dict = cg.client.read_nodes( + node_ids=non_latest_ids, + properties=attributes.Hierarchy.NewParent, + 
end_time=timestamp, + ) + + new_roots_ts = [] + for n in node_ids: + try: + v = row_dict[n] + new_roots_ts.append(v[-1].timestamp.timestamp()) # sorted descending + except KeyError: + ... + new_roots_ts = deque(new_roots_ts) + + result = [] + for x in latest_mask: + if x: + result.append(timestamp.timestamp()) + else: + result.append(new_roots_ts.popleft()) + return result + + +def handle_root_timestamps(table_id, is_binary, latest: bool = False): + current_app.request_type = "root_timestamps" + current_app.table_id = table_id + + if is_binary: + node_ids = np.frombuffer(request.data, np.uint64) + else: + node_ids = np.array(json.loads(request.data)["node_ids"], dtype=np.uint64) + + cg = app_utils.get_cg(table_id) + timestamp = _parse_timestamp("timestamp", time.time(), return_datetime=True) + if latest: + return _handle_latest(cg, node_ids, timestamp) + else: + timestamps = cg.get_node_timestamps(node_ids, return_numpy=False) + return [ts.timestamp() for ts in timestamps] + + +### OPERATION DETAILS ------------------------------------------------------------ + + +def operation_details(table_id): + from pychunkedgraph.export.operation_logs import parse_attr + + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + operation_ids = json.loads(request.args.get("operation_ids", "[]")) + + cg = app_utils.get_cg(table_id) + log_rows = cg.client.read_log_entries(operation_ids) + + result = {} + for k, v in log_rows.items(): + details = {} + for _k, _v in v.items(): + _k, _v = parse_attr(_k, _v) + try: + details[_k.decode("utf-8")] = _v + except AttributeError: + details[_k] = _v + result[int(k)] = details + return result + + +### DELTA ROOTS ------------------------------------------------------------ + + +def delta_roots(table_id): + current_app.table_id = table_id + + timestamp_past = _parse_timestamp("timestamp_past", None, return_datetime=True) + timestamp_future = _parse_timestamp( + "timestamp_future", time.time(), return_datetime=True + ) + cg = app_utils.get_cg(table_id) + old_roots, new_roots = cg.get_proofread_root_ids(timestamp_past, timestamp_future) + return {"old_roots": old_roots, "new_roots": new_roots} + + +### VALID NODES -------------------------------------------------------------- + + +def valid_nodes(table_id, is_binary): + current_app.request_type = "valid_nodes" + current_app.table_id = table_id + + if is_binary: + node_ids = np.frombuffer(request.data, np.uint64) + else: + node_ids = np.array(json.loads(request.data)["node_ids"], dtype=np.uint64) + + # Convert seconds since epoch to UTC datetime + end_timestamp = _parse_timestamp( + "end_timestamp", None, return_datetime=True, allow_none=True + ) + start_timestamp = _parse_timestamp( + "start_timestamp", None, return_datetime=True, allow_none=True + ) + + # Call ChunkedGraph + cg = app_utils.get_cg(table_id) + rows = cg.client.read_nodes( + node_ids=node_ids, start_time=start_timestamp, end_time=end_timestamp + ) + resp = {"valid_roots": np.array(list(rows.keys()), dtype=basetypes.NODE_ID)} + return resp diff --git a/pychunkedgraph/backend/utils/__init__.py b/pychunkedgraph/app/segmentation/generic/__init__.py similarity index 100% rename from pychunkedgraph/backend/utils/__init__.py rename to pychunkedgraph/app/segmentation/generic/__init__.py diff --git a/pychunkedgraph/app/segmentation/generic/routes.py b/pychunkedgraph/app/segmentation/generic/routes.py new file mode 100644 index 000000000..b67597c3d --- /dev/null +++ b/pychunkedgraph/app/segmentation/generic/routes.py @@ -0,0 
+1,95 @@ +# pylint: disable=invalid-name, missing-docstring + +from flask import Blueprint +from middle_auth_client import ( + auth_required, + auth_requires_admin, + auth_requires_permission, +) + +from pychunkedgraph.app import common as app_common +from pychunkedgraph.app.app_utils import remap_public +from pychunkedgraph.app.segmentation import common +from pychunkedgraph.graph import exceptions as cg_exceptions + +bp = Blueprint( + "pcg_generic_v1", __name__, url_prefix=f"/{common.__segmentation_url_prefix__}" +) + + +# ------------------------------- +# ------ Access control and index +# ------------------------------- + + +@bp.route("/") +@bp.route("/index") +@auth_required +def index(): + return common.index() + + +@bp.route +@auth_required +def home(): + return common.home() + + +# ------------------------------- +# ------ Measurements and Logging +# ------------------------------- + + +@bp.before_request +def before_request(): + return app_common.before_request() + + +@bp.after_request +def after_request(response): + return app_common.after_request(response) + + +@bp.errorhandler(Exception) +def unhandled_exception(e): + return app_common.unhandled_exception(e) + + +@bp.errorhandler(cg_exceptions.ChunkedGraphAPIError) +def api_exception(e): + return app_common.api_exception(e) + + +# ------------------- +# ------ Applications +# ------------------- + + +@bp.route("/sleep/") +@auth_requires_admin +def sleep_me(sleep): + return common.sleep_me(sleep) + + +@bp.route("/table//info", methods=["GET"]) +@auth_requires_permission("view", public_table_key="table_id") +@remap_public +def handle_info(table_id): + return common.handle_info(table_id) + + +# ------------------- +# ------ API versions +# ------------------- + + +@bp.route("/api/versions", methods=["GET"]) +@auth_required +def handle_api_versions(): + return common.handle_api_versions() + + +@bp.route("/api/version", methods=["GET"]) +@auth_required +def handle_version(): + return common.handle_version() diff --git a/pychunkedgraph/app/segmentation/legacy/routes.py b/pychunkedgraph/app/segmentation/legacy/routes.py new file mode 100644 index 000000000..8ebc0475d --- /dev/null +++ b/pychunkedgraph/app/segmentation/legacy/routes.py @@ -0,0 +1,199 @@ +import json + +import numpy as np + +from flask import Blueprint, jsonify, request +from middle_auth_client import ( + auth_requires_admin, + auth_requires_permission, + auth_required, +) +from pychunkedgraph.app import app_utils +from pychunkedgraph.app import common as app_common +from pychunkedgraph.app.segmentation import common +from pychunkedgraph.graph import exceptions as cg_exceptions +from pychunkedgraph.app.app_utils import remap_public + +bp = Blueprint( + "pcg_segmentation_v0", + __name__, + url_prefix=f"/{common.__segmentation_url_prefix__}/1.0", +) + + +# ------------------------------- +# ------ Access control and index +# ------------------------------- + + +@bp.route("/") +@bp.route("/index") +@auth_required +def index(): + return common.index() + + +@bp.route +@auth_required +def home(): + return common.home() + + +# ------------------------------- +# ------ Measurements and Logging +# ------------------------------- + + +@bp.before_request +@auth_required +def before_request(): + return app_common.before_request() + + +@bp.after_request +@auth_required +def after_request(response): + return app_common.after_request(response) + + +@bp.errorhandler(Exception) +def unhandled_exception(e): + return app_common.unhandled_exception(e) + + 
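+# ChunkedGraph-specific API errors (cg_exceptions.ChunkedGraphAPIError) are routed
+# to the dedicated handler below; anything else falls through to the generic
+# unhandled_exception handler above.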
+@bp.errorhandler(cg_exceptions.ChunkedGraphAPIError) +def api_exception(e): + return app_common.api_exception(e) + + +# ------------------- +# ------ Applications +# ------------------- + + +@bp.route("/sleep/") +@auth_requires_admin +def sleep_me(sleep): + return common.sleep_me(sleep) + + +@bp.route("//info", methods=["GET"]) +@auth_requires_permission("view") +@remap_public +def handle_info(table_id): + print("table_id", table_id) + return common.handle_info(table_id) + + +### MERGE ---------------------------------------------------------------------- + + +@bp.route("//graph/merge", methods=["POST", "GET"]) +@auth_requires_permission("edit") +def handle_merge(table_id): + merge_result = common.handle_merge(table_id) + return app_utils.tobinary(merge_result.new_root_ids) + + +### SPLIT ---------------------------------------------------------------------- + + +@bp.route("//graph/split", methods=["POST", "GET"]) +@auth_requires_permission("edit") +def handle_split(table_id): + split_result = common.handle_split(table_id) + return app_utils.tobinary(split_result.new_root_ids) + + +### GET ROOT ------------------------------------------------------------------- + + +@bp.route("//graph/root", methods=["POST", "GET"]) +@auth_requires_permission("view") +def handle_root_1(table_id): + atomic_id = np.uint64(json.loads(request.data)[0]) + root_id = common.handle_root(table_id, atomic_id) + return app_utils.tobinary(root_id) + + +@bp.route("//graph//root", methods=["POST", "GET"]) +@auth_requires_permission("view") +def handle_root_2(table_id, atomic_id): + root_id = common.handle_root(table_id, atomic_id) + return app_utils.tobinary(root_id) + + +### CHILDREN ------------------------------------------------------------------- + + +@bp.route("//segment//children", methods=["POST", "GET"]) +@auth_requires_permission("view") +def handle_children(table_id, parent_id): + children_ids = common.handle_children(table_id, parent_id) + return app_utils.tobinary(children_ids) + + +### LEAVES --------------------------------------------------------------------- + + +@bp.route("//segment//leaves", methods=["POST", "GET"]) +@auth_requires_permission("view") +def handle_leaves(table_id, root_id): + leaf_ids = common.handle_leaves(table_id, root_id) + return app_utils.tobinary(leaf_ids) + + +### LEAVES FROM LEAVES --------------------------------------------------------- + + +@bp.route("//segment//leaves_from_leave", methods=["POST", "GET"]) +@auth_requires_permission("view") +def handle_leaves_from_leave(table_id, atomic_id): + leaf_ids = common.handle_leaves_from_leave(table_id, atomic_id) + return app_utils.tobinary(leaf_ids) + + +### SUBGRAPH ------------------------------------------------------------------- + + +@bp.route("//segment//subgraph", methods=["POST", "GET"]) +@auth_requires_permission("view") +def handle_subgraph(table_id, root_id): + subgraph_result = common.handle_subgraph(table_id, root_id) + return app_utils.tobinary(subgraph_result) + + +### CONTACT SITES -------------------------------------------------------------- + + +@bp.route("//segment//contact_sites", methods=["POST", "GET"]) +@auth_requires_permission("view") +def handle_contact_sites(table_id, root_id): + contact_sites = common.handle_contact_sites(table_id, root_id) + return jsonify(contact_sites) + + +### CHANGE LOG ----------------------------------------------------------------- + + +@bp.route("//segment//change_log", methods=["POST", "GET"]) +@auth_requires_permission("view") +def change_log(table_id, root_id): + log = 
common.change_log(table_id, root_id) + return jsonify(log) + + +@bp.route("//segment//merge_log", methods=["POST", "GET"]) +@auth_requires_permission("view") +def merge_log(table_id, root_id): + log = common.merge_log(table_id, root_id) + return jsonify(log) + + +@bp.route("//graph/oldest_timestamp", methods=["POST", "GET"]) +@auth_requires_permission("view") +def oldest_timestamp(table_id): + delimiter = request.args.get("delimiter", " ") + earliest_timestamp = common.oldest_timestamp(table_id) + resp = {"iso": earliest_timestamp.isoformat(delimiter)} + return jsonify(resp) diff --git a/pychunkedgraph/app/segmentation/v1/routes.py b/pychunkedgraph/app/segmentation/v1/routes.py new file mode 100644 index 000000000..5d0920c66 --- /dev/null +++ b/pychunkedgraph/app/segmentation/v1/routes.py @@ -0,0 +1,647 @@ +# pylint: disable=invalid-name, missing-docstring + +import csv +import io +import json +import pickle + +import numpy as np +import pandas as pd +from flask import Blueprint, make_response, request +from middle_auth_client import ( + auth_required, + auth_requires_admin, + auth_requires_permission, +) + +from pychunkedgraph.app import common as app_common +from pychunkedgraph.app import app_utils +from pychunkedgraph.app.app_utils import ( + jsonify_with_kwargs, + remap_public, + tobinary, + toboolean, +) +from pychunkedgraph.app.segmentation import common +from pychunkedgraph.graph import exceptions as cg_exceptions + +bp = Blueprint( + "pcg_segmentation_v1", + __name__, + url_prefix=f"/{common.__segmentation_url_prefix__}/api/v1", +) + +# ------------------------------- +# ------ Access control and index +# ------------------------------- + + +@bp.route("/") +@bp.route("/index") +@auth_required +def index(): + return common.index() + + +@bp.route +@auth_required +def home(): + return common.home() + + +# ------------------------------- +# ------ Measurements and Logging +# ------------------------------- + + +@bp.before_request +def before_request(): + return app_common.before_request() + + +@bp.after_request +def after_request(response): + return app_common.after_request(response) + + +@bp.errorhandler(Exception) +def unhandled_exception(e): + return app_common.unhandled_exception(e) + + +@bp.errorhandler(cg_exceptions.ChunkedGraphAPIError) +def api_exception(e): + return app_common.api_exception(e) + + +### MERGE ---------------------------------------------------------------------- + + +@bp.route("/table//merge", methods=["POST"]) +@auth_requires_permission("edit") +@remap_public(edit=True) +def handle_merge(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + merge_result = common.handle_merge(table_id) + resp = { + "operation_id": merge_result.operation_id, + "new_root_ids": merge_result.new_root_ids, + } + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +@bp.route("/table//merge_admin", methods=["POST"]) +@auth_requires_permission("admin") +@remap_public(edit=True) +def handle_merge_admin(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + allow_same_segment_merge = request.args.get( + "allow_same_segment_merge", False, type=common.str2bool + ) + merge_result = common.handle_merge( + table_id, allow_same_segment_merge=allow_same_segment_merge + ) + resp = { + "operation_id": merge_result.operation_id, + "new_root_ids": merge_result.new_root_ids, + } + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### SPLIT 
---------------------------------------------------------------------- + + +@bp.route("/table//split", methods=["POST"]) +@auth_requires_permission("edit") +@remap_public(edit=True) +def handle_split(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + split_result = common.handle_split(table_id) + resp = { + "operation_id": split_result.operation_id, + "new_root_ids": split_result.new_root_ids, + } + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +@bp.route("/table//graph/split_preview", methods=["POST"]) +@auth_requires_permission("view") +@remap_public(edit=True) +def handle_split_preview(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + split_preview = common.handle_split_preview(table_id) + return jsonify_with_kwargs(split_preview, int64_as_str=int64_as_str) + + +### UNDO ---------------------------------------------------------------------- + + +@bp.route("/table//undo", methods=["POST"]) +@auth_requires_permission("edit") +@remap_public(edit=True) +def handle_undo(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + undo_result = common.handle_undo(table_id) + resp = { + "operation_id": undo_result.operation_id, + "new_root_ids": undo_result.new_root_ids, + } + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### REDO ---------------------------------------------------------------------- + + +@bp.route("/table//redo", methods=["POST"]) +@auth_requires_permission("edit") +@remap_public(edit=True) +def handle_redo(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + redo_result = common.handle_redo(table_id) + resp = { + "operation_id": redo_result.operation_id, + "new_root_ids": redo_result.new_root_ids, + } + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### ROLLBACK USER -------------------------------------------------------------- + + +@bp.route("/table//rollback_user", methods=["POST"]) +@auth_requires_admin +@remap_public(edit=True) +def handle_rollback(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + rollback_result = common.handle_rollback(table_id) + resp = rollback_result + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### USER OPERATIONS ------------------------------------------------------------- + + +@bp.route("/table//user_operations", methods=["GET"]) +@auth_requires_permission("admin_view") +@remap_public(edit=True) +def handle_user_operations(table_id): + disp = request.args.get("disp", default=False, type=toboolean) + include_undone = request.args.get("include_undone", default=False, type=toboolean) + user_operations = pd.DataFrame.from_dict( + common.all_user_operations(table_id, include_undone) + ) + + if disp: + return user_operations.to_html() + else: + return user_operations.to_json() + + +### GET ROOT ------------------------------------------------------------------- + + +@bp.route("/table//node//root", methods=["GET"]) +@auth_requires_permission("view") +@remap_public +def handle_root(table_id, node_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + root_id = common.handle_root(table_id, node_id) + resp = {"root_id": root_id} + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### GET ROOTS ------------------------------------------------------------------ + + +@bp.route("/table//roots", methods=["POST"]) 
+@auth_requires_permission("view") +@remap_public(edit=False) +def handle_roots(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + root_ids = common.handle_roots(table_id, is_binary=False) + resp = {"root_ids": root_ids} + + arg_as_binary = request.args.get("as_binary", default="", type=str) + if arg_as_binary in resp: + return tobinary(resp[arg_as_binary]) + else: + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### GET ROOTS BINARY ----------------------------------------------------------- + + +@bp.route("/table//roots_binary", methods=["POST"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_roots_binary(table_id): + root_ids = common.handle_roots(table_id, is_binary=True) + return tobinary(root_ids) + + +### CHILDREN ------------------------------------------------------------------- + + +@bp.route("/table//node//children", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_children(table_id, node_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + children_ids = common.handle_children(table_id, node_id) + resp = {"children_ids": children_ids} + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### GET L2:SV MAPPINGS OF A L2 CHUNK ------------------------------------------------------------------ + + +@bp.route("/table//l2_chunk_children/", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_l2_chunk_children(table_id, chunk_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + as_array = request.args.get("as_array", default=False, type=toboolean) + l2_chunk_children = common.handle_l2_chunk_children(table_id, chunk_id, as_array) + if as_array: + resp = {"l2_chunk_children": l2_chunk_children} + else: + resp = {"l2_chunk_children": pickle.dumps(l2_chunk_children)} + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### GET L2:SV MAPPINGS OF A L2 CHUNK BINARY ------------------------------------------------------------------ + + +@bp.route("/table//l2_chunk_children_binary/", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_l2_chunk_children_binary(table_id, chunk_id): + as_array = request.args.get("as_array", default=False, type=toboolean) + l2_chunk_children = common.handle_l2_chunk_children(table_id, chunk_id, as_array) + if as_array: + return tobinary(l2_chunk_children) + else: + return pickle.dumps(l2_chunk_children) + + +### LEAVES --------------------------------------------------------------------- + + +@bp.route("/table//node//leaves", methods=["GET"]) +@auth_requires_permission( + "view", + public_table_key="table_id", + public_node_key="node_id", +) +@remap_public(edit=False) +def handle_leaves(table_id, node_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + leaf_ids = common.handle_leaves(table_id, node_id) + resp = {"leaf_ids": leaf_ids} + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +## LEAVES OF MANY ROOTS + + +@bp.route("/table//node/leaves_many", methods=["POST"]) +@bp.route("/table//leaves_many", methods=["POST"]) +@auth_requires_permission( + "view", + public_table_key="table_id", + public_node_key="node_id", +) +@remap_public(check_node_ids=True) +def handle_leaves_many(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + root_to_leaf_dict = 
common.handle_leaves_many(table_id) + return jsonify_with_kwargs(root_to_leaf_dict, int64_as_str=int64_as_str) + + +### GET MINIMAL COVERING NODES + + +@bp.route("/table//minimal_covering_nodes", methods=["POST"]) +@auth_requires_permission("view", public_table_key="table_id") +@remap_public(check_node_ids=False) +def handle_minimal_covering_nodes(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + is_binary = request.args.get("is_binary", default=False, type=toboolean) + + covering_nodes = common.handle_find_minimal_covering_nodes(table_id, is_binary=is_binary) + as_array = request.args.get("as_array", default=False, type=toboolean) + if as_array: + return tobinary(covering_nodes) + return jsonify_with_kwargs(covering_nodes, int64_as_str=int64_as_str) + + +### SUBGRAPH ------------------------------------------------------------------- + + +@bp.route("/table//node//subgraph", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_subgraph(table_id, node_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + edges = common.handle_subgraph(table_id, node_id) + resp = { + "nodes": edges.get_pairs(), + "affinities": edges.affinities, + "areas": edges.areas, + } + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### CONTACT SITES -------------------------------------------------------------- + + +@bp.route("/table//node//contact_sites", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_contact_sites(table_id, node_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + contact_sites, contact_site_metadata = common.handle_contact_sites( + table_id, node_id + ) + resp = { + "contact_sites": contact_sites, + "contact_site_metadata": contact_site_metadata, + } + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +@bp.route( + "/table//node/contact_sites_pair//", + methods=["GET"], +) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_pairwise_contact_sites(table_id, first_node_id, second_node_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + contact_sites, contact_site_metadata = common.handle_pairwise_contact_sites( + table_id, first_node_id, second_node_id + ) + resp = { + "contact_sites": contact_sites, + "contact_site_metadata": contact_site_metadata, + } + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### CHANGE LOG ----------------------------------------------------------------- + + +@bp.route("/table//change_log", methods=["GET"]) +@auth_requires_admin +@remap_public(edit=False) +def change_log_full(table_id): + si = io.StringIO() + cw = csv.writer(si) + log_entries = common.change_log(table_id) + cw.writerow(["user_id", "action", "root_ids", "timestamp"]) + cw.writerows(log_entries) + output = make_response(si.getvalue()) + output.headers["Content-Disposition"] = f"attachment; filename={table_id}.csv" + output.headers["Content-type"] = "text/csv" + return output + + +@bp.route("/table//tabular_change_log_recent", methods=["GET"]) +@auth_requires_permission("admin_view") +@remap_public(edit=True) +def tabular_change_log_weekly(table_id): + disp = request.args.get("disp", default=False, type=toboolean) + weekly_tab_change_log = common.tabular_change_log_recent(table_id) + + if disp: + return weekly_tab_change_log.to_html() + else: + return weekly_tab_change_log.to_json() + + 
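Nearly all of the v1 routes in this file share the same calling convention: the `int64_as_str` query argument (parsed by `toboolean`) tells `jsonify_with_kwargs` to serialize uint64 identifiers as strings, since JSON consumers generally cannot hold the full 64-bit integer range losslessly, and the tabular endpoints additionally accept `disp` to switch between HTML and JSON output. A minimal client-side sketch of that convention follows; the base URL, table name, root id, and bearer token are placeholder assumptions (not part of this changeset), and the path parameters are inferred from the view-function signatures:

import requests

# All values below are placeholders, not part of this changeset.
BASE_URL = "https://example.org/segmentation/api/v1"  # hypothetical deployment
TABLE_ID = "example_table"                            # hypothetical chunkedgraph table
ROOT_ID = "864691135000000000"                        # hypothetical root id
TOKEN = "..."                                         # token validated by middle_auth_client

# Fetch the supervoxel (leaf) ids of a root via the node /leaves route defined above.
# int64_as_str=true asks the server to return uint64 ids as strings.
resp = requests.get(
    f"{BASE_URL}/table/{TABLE_ID}/node/{ROOT_ID}/leaves",
    params={"int64_as_str": "true"},
    headers={"Authorization": f"Bearer {TOKEN}"},
)
resp.raise_for_status()
leaf_ids = resp.json()["leaf_ids"]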
+@bp.route("/table//root//change_log", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def change_log(table_id, root_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + filtered = request.args.get("filtered", default=False, type=toboolean) + log = common.change_log(table_id, root_id, filtered) + return jsonify_with_kwargs(log, int64_as_str=int64_as_str) + + +@bp.route("/table//root//tabular_change_log", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def tabular_change_log(table_id, root_id): + disp = request.args.get("disp", default=False, type=toboolean) + # get_root_ids = request.args.get("root_ids", default=False, type=toboolean) + filtered = request.args.get("filtered", default=True, type=toboolean) + tab_change_log_dict = common.tabular_change_logs(table_id, [int(root_id)], filtered) + tab_change_log = tab_change_log_dict[int(root_id)] + if disp: + return tab_change_log.to_html() + return tab_change_log.to_json() + + +@bp.route("/table//tabular_change_log_many", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def tabular_change_log_many(table_id): + filtered = request.args.get("filtered", default=True, type=toboolean) + root_ids = np.array(json.loads(request.data)["root_ids"], dtype=np.uint64) + tab_change_log_dict = common.tabular_change_logs(table_id, root_ids, filtered) + + return jsonify_with_kwargs( + {str(k): tab_change_log_dict[k] for k in tab_change_log_dict.keys()} + ) + + +@bp.route("/table//root//merge_log", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def merge_log(table_id, root_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + log = common.merge_log(table_id, root_id) + return jsonify_with_kwargs(log, int64_as_str=int64_as_str) + + +@bp.route("/table//root//lineage_graph", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_lineage_graph(table_id, root_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + resp = common.handle_lineage_graph(table_id, root_id) + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +@bp.route("/table//lineage_graph_multiple", methods=["POST"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_lineage_graph_multiple(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + resp = common.handle_lineage_graph(table_id) + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +@bp.route("/table//past_id_mapping", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_past_id_mapping(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + resp = common.handle_past_id_mapping(table_id) + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +@bp.route("/table//oldest_timestamp", methods=["GET"]) +@auth_requires_permission( + "view", + public_table_key="table_id", +) +@remap_public(edit=False) +def oldest_timestamp(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + delimiter = request.args.get("delimiter", default=" ", type=str) + earliest_timestamp = common.oldest_timestamp(table_id) + resp = {"iso": earliest_timestamp.isoformat(delimiter)} + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +@bp.route("/table//root//last_edit", methods=["GET"]) 
+@auth_requires_permission("view") +@remap_public(edit=False) +def last_edit(table_id, root_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + delimiter = request.args.get("delimiter", default=" ", type=str) + latest_timestamp = common.last_edit(table_id, root_id) + resp = {"iso": latest_timestamp.isoformat(delimiter)} + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### FIND PATH ------------------------------------------------------------------ + + +@bp.route("/table//graph/find_path", methods=["POST"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def find_path(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + precision_mode = request.args.get("precision_mode", default=True, type=toboolean) + find_path_result = common.handle_find_path(table_id, precision_mode) + return jsonify_with_kwargs(find_path_result, int64_as_str=int64_as_str) + + +### ROOT INFO ----------------------------------------------------------------- + + +@bp.route("/table//is_latest_roots", methods=["POST"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_is_latest_roots(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + is_latest_roots = common.handle_is_latest_roots(table_id, is_binary=False) + resp = {"is_latest": is_latest_roots} + + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +@bp.route("/table//root_timestamps", methods=["POST"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_root_timestamps(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + is_binary = request.args.get("is_binary", default=False, type=toboolean) + latest = request.args.get("latest", default=False, type=toboolean) + is_binary = request.args.get("is_binary", default=False, type=toboolean) + root_timestamps = common.handle_root_timestamps( + table_id, is_binary=is_binary, latest=latest + ) + resp = {"timestamp": root_timestamps} + + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +## Lookup root id from coordinate ----------------------------------------------- + + +@bp.route("/table//roots_from_coords", methods=["POST"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def handle_roots_from_coords(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + resp = common.handle_roots_from_coord(table_id) + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +## Get level2 graph ------------------------------------------------------------- +@bp.route("/table//node//lvl2_graph", methods=["GET"]) +@auth_requires_permission( + "view", + public_table_key="table_id", + public_node_key="node_id", +) +@remap_public(edit=False) +def handle_get_lvl2_graph(table_id, node_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + resp = common.handle_get_layer2_graph(table_id, node_id) + out = jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + if "bounds" in request.args: + out.headers["Used-Bounds"] = True + else: + out.headers["Used-Bounds"] = False + return out + + +### GET OPERATION DETAILS -------------------------------------------------------- + + +@bp.route("/table//operation_details", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def operation_details(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) 
+ resp = common.operation_details(table_id) + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### GET PROOFREAD IDS -------------------------------------------------------- + + +@bp.route("/table//delta_roots", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def delta_roots(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + resp = common.delta_roots(table_id) + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +### GET VALID NODES ------------------------------------------------------------- + + +@bp.route("/table//valid_nodes", methods=["GET"]) +@auth_requires_permission("view") +@remap_public(edit=False) +def valid_nodes(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + is_binary = request.args.get("is_binary", default=False, type=toboolean) + resp = common.valid_nodes(table_id, is_binary=is_binary) + + return jsonify_with_kwargs(resp, int64_as_str=int64_as_str) + + +@bp.route("/table//supervoxel_lookup", methods=["POST"]) +@auth_requires_permission("admin") +@remap_public(edit=False) +def handle_supervoxel_lookup(table_id): + int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean) + + nodes = json.loads(request.data) + cg = app_utils.get_cg(table_id) + node_ids = [] + coords = [] + for node in nodes: + node_ids.append(node[0]) + coords.append(np.array(node[1:]) / cg.segmentation_resolution) + + atomic_ids = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids) + return jsonify_with_kwargs(atomic_ids, int64_as_str=int64_as_str) diff --git a/pychunkedgraph/backend/chunkedgraph.py b/pychunkedgraph/backend/chunkedgraph.py deleted file mode 100644 index c4662b125..000000000 --- a/pychunkedgraph/backend/chunkedgraph.py +++ /dev/null @@ -1,3648 +0,0 @@ -import collections -import numpy as np -import time -import datetime -import os -import sys -import networkx as nx -import pytz -import cloudvolume -import re -import itertools -import logging - -from itertools import chain -from multiwrapper import multiprocessing_utils as mu -from pychunkedgraph.backend import cutting, chunkedgraph_comp, flatgraph_utils -from pychunkedgraph.backend.chunkedgraph_utils import compute_indices_pandas, \ - compute_bitmasks, get_google_compatible_time_stamp, \ - get_time_range_filter, get_time_range_and_column_filter, get_max_time, \ - combine_cross_chunk_edge_dicts, get_min_time, partial_row_data_to_column_dict -from pychunkedgraph.backend.utils import serializers, column_keys, row_keys, basetypes -from pychunkedgraph.backend import chunkedgraph_exceptions as cg_exceptions, \ - chunkedgraph_edits as cg_edits -from pychunkedgraph.backend.graphoperation import ( - GraphEditOperation, - MergeOperation, - MulticutOperation, - SplitOperation, -) -# from pychunkedgraph.meshing import meshgen - -from google.api_core.retry import Retry, if_exception_type -from google.api_core.exceptions import Aborted, DeadlineExceeded, \ - ServiceUnavailable -from google.auth import credentials -from google.cloud import bigtable -from google.cloud.bigtable.row_filters import TimestampRange, \ - TimestampRangeFilter, ColumnRangeFilter, ValueRangeFilter, RowFilterChain, \ - ColumnQualifierRegexFilter, RowFilterUnion, ConditionalRowFilter, \ - PassAllFilter, RowFilter, RowKeyRegexFilter, FamilyNameRegexFilter -from google.cloud.bigtable.row_set import RowSet -from google.cloud.bigtable.column_family import MaxVersionsGCRule - -from typing import Any, Dict, 
Iterable, List, Optional, Sequence, Tuple, Union, NamedTuple - - -HOME = os.path.expanduser("~") -N_DIGITS_UINT64 = len(str(np.iinfo(np.uint64).max)) -N_BITS_PER_ROOT_COUNTER = np.uint64(8) -LOCK_EXPIRED_TIME_DELTA = datetime.timedelta(minutes=3, seconds=0) -UTC = pytz.UTC - -# Setting environment wide credential path -os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = \ - HOME + "/.cloudvolume/secrets/google-secret.json" - - -class ChunkedGraph(object): - def __init__(self, - table_id: str, - instance_id: str = "pychunkedgraph", - project_id: str = "neuromancer-seung-import", - chunk_size: Tuple[np.uint64, np.uint64, np.uint64] = None, - fan_out: Optional[np.uint64] = None, - use_skip_connections: Optional[bool] = True, - s_bits_atomic_layer: Optional[np.uint64] = 8, - n_bits_root_counter: Optional[np.uint64] = 0, - n_layers: Optional[np.uint64] = None, - credentials: Optional[credentials.Credentials] = None, - client: bigtable.Client = None, - dataset_info: Optional[object] = None, - is_new: bool = False, - logger: Optional[logging.Logger] = None) -> None: - - if logger is None: - self.logger = logging.getLogger(f"{project_id}/{instance_id}/{table_id}") - self.logger.setLevel(logging.WARNING) - if not self.logger.handlers: - sh = logging.StreamHandler(sys.stdout) - sh.setLevel(logging.WARNING) - self.logger.addHandler(sh) - else: - self.logger = logger - - if client is not None: - self._client = client - else: - self._client = bigtable.Client(project=project_id, admin=True, - credentials=credentials) - - self._instance = self.client.instance(instance_id) - self._table_id = table_id - - self._table = self.instance.table(self.table_id) - - if is_new: - self._check_and_create_table() - - self._dataset_info = self.check_and_write_table_parameters( - column_keys.GraphSettings.DatasetInfo, dataset_info, - required=True, is_new=is_new) - - self._cv_path = self._dataset_info["data_dir"] # required - self._mesh_dir = self._dataset_info.get("mesh", None) # optional - - self._n_layers = self.check_and_write_table_parameters( - column_keys.GraphSettings.LayerCount, n_layers, - required=True, is_new=is_new) - self._fan_out = self.check_and_write_table_parameters( - column_keys.GraphSettings.FanOut, fan_out, - required=True, is_new=is_new) - s_bits_atomic_layer = self.check_and_write_table_parameters( - column_keys.GraphSettings.SpatialBits, - np.uint64(s_bits_atomic_layer), - required=False, is_new=is_new) - self._use_skip_connections = self.check_and_write_table_parameters( - column_keys.GraphSettings.SkipConnections, - np.uint64(use_skip_connections), required=False, is_new=is_new) > 0 - self._n_bits_root_counter = self.check_and_write_table_parameters( - column_keys.GraphSettings.RootCounterBits, - np.uint64(n_bits_root_counter), - required=False, is_new=is_new) - self._chunk_size = self.check_and_write_table_parameters( - column_keys.GraphSettings.ChunkSize, chunk_size, - required=True, is_new=is_new) - - self._dataset_info["graph"] = {"chunk_size": self.chunk_size} - - self._bitmasks = compute_bitmasks(self.n_layers, self.fan_out, - s_bits_atomic_layer) - - self._cv = None - - # Hardcoded parameters - self._n_bits_for_layer_id = 8 - self._cv_mip = 0 - - # Vectorized calls - self._get_chunk_layer_vec = np.vectorize(self.get_chunk_layer) - self._get_chunk_id_vec = np.vectorize(self.get_chunk_id) - - @property - def client(self) -> bigtable.Client: - return self._client - - @property - def instance(self) -> bigtable.instance.Instance: - return self._instance - - @property - def table(self) -> 
bigtable.table.Table: - return self._table - - @property - def table_id(self) -> str: - return self._table_id - - @property - def instance_id(self): - return self.instance.instance_id - - @property - def project_id(self): - return self.client.project - - @property - def family_id(self) -> str: - return "0" - - @property - def incrementer_family_id(self) -> str: - return "1" - - @property - def log_family_id(self) -> str: - return "2" - - @property - def cross_edge_family_id(self) -> str: - return "3" - - @property - def family_ids(self): - return [self.family_id, self.incrementer_family_id, self.log_family_id, - self.cross_edge_family_id] - - @property - def fan_out(self) -> np.uint64: - return self._fan_out - - @property - def chunk_size(self) -> np.ndarray: - return self._chunk_size - - @property - def n_bits_root_counter(self) -> np.ndarray: - return self._n_bits_root_counter - - @property - def use_skip_connections(self) -> np.ndarray: - return self._use_skip_connections - - @property - def segmentation_chunk_size(self) -> np.ndarray: - return self.cv.scale["chunk_sizes"][0] - - @property - def segmentation_resolution(self) -> np.ndarray: - return np.array(self.cv.scale["resolution"]) - - @property - def segmentation_bounds(self) -> np.ndarray: - return np.array(self.cv.bounds.to_list()).reshape(2, 3) - - @property - def n_layers(self) -> int: - return int(self._n_layers) - - @property - def bitmasks(self) -> Dict[int, int]: - return self._bitmasks - - @property - def cv_mesh_path(self) -> str: - return "%s/%s" % (self._cv_path, self._mesh_dir) - - @property - def dataset_info(self) -> object: - return self._dataset_info - - @property - def cv_mip(self) -> int: - return self._cv_mip - - @property - def cv(self) -> cloudvolume.CloudVolume: - if self._cv is None: - self._cv = cloudvolume.CloudVolume(self._cv_path, mip=self._cv_mip, - info=self.dataset_info) - - return self._cv - - @property - def vx_vol_bounds(self): - return np.array(self.cv.bounds.to_list()).reshape(2, -1).T - - @property - def root_chunk_id(self): - return self.get_chunk_id(layer=int(self.n_layers), x=0, y=0, z=0) - - def _check_and_create_table(self) -> None: - """ Checks if table exists and creates new one if necessary """ - table_ids = [t.table_id for t in self.instance.list_tables()] - - if not self.table_id in table_ids: - self.table.create() - f = self.table.column_family(self.family_id) - f.create() - - f_inc = self.table.column_family(self.incrementer_family_id, - gc_rule=MaxVersionsGCRule(1)) - f_inc.create() - - f_log = self.table.column_family(self.log_family_id) - f_log.create() - - f_ce = self.table.column_family(self.cross_edge_family_id, - gc_rule=MaxVersionsGCRule(1)) - f_ce.create() - - self.logger.info(f"Table {self.table_id} created") - - def check_and_write_table_parameters(self, column: column_keys._Column, - value: Optional[Union[str, np.uint64]] = None, - required: bool = True, - is_new: bool = False - ) -> Union[str, np.uint64]: - """ Checks if a parameter already exists in the table. If it already - exists it returns the stored value, else it stores the given value. - Storing the given values can be enforced with `is_new`. The function - raises an exception if no value is passed and the parameter does not - exist, yet. 
- - :param column: column_keys._Column - :param value: Union[str, np.uint64] - :param required: bool - :param is_new: bool - :return: Union[str, np.uint64] - value - """ - setting = self.read_byte_row(row_key=row_keys.GraphSettings, - columns=column) - - if (not setting or is_new) and value is not None: - row = self.mutate_row(row_keys.GraphSettings, {column: value}) - self.bulk_write([row]) - elif not setting and value is None: - assert not required - return None - else: - value = setting[0].value - - return value - - def is_in_bounds(self, coordinate: Sequence[int]): - """ Checks whether a coordinate is within the segmentation bounds - - :param coordinate: [int, int, int] - :return bool - """ - coordinate = np.array(coordinate) - - if np.any(coordinate < self.segmentation_bounds[0]): - return False - elif np.any(coordinate > self.segmentation_bounds[1]): - return False - else: - return True - - def get_serialized_info(self): - """ Rerturns dictionary that can be used to load this ChunkedGraph - - :return: dict - """ - info = {"table_id": self.table_id, - "instance_id": self.instance_id, - "project_id": self.project_id} - - try: - info["credentials"] = self.client.credentials - except: - info["credentials"] = self.client._credentials - - return info - - - def adjust_vol_coordinates_to_cv(self, x: np.int, y: np.int, z: np.int, - resolution: Sequence[np.int]): - resolution = np.array(resolution) - scaling = np.array(self.cv.resolution / resolution, dtype=np.int) - - x = (x / scaling[0] - self.vx_vol_bounds[0, 0]) - y = (y / scaling[1] - self.vx_vol_bounds[1, 0]) - z = (z / scaling[2] - self.vx_vol_bounds[2, 0]) - - return np.array([x, y, z]) - - def get_chunk_coordinates_from_vol_coordinates(self, - x: np.int, - y: np.int, - z: np.int, - resolution: Sequence[np.int], - ceil: bool = False, - layer: int = 1 - ) -> np.ndarray: - """ Translates volume coordinates to chunk_coordinates - - :param x: np.int - :param y: np.int - :param z: np.int - :param resolution: np.ndarray - :param ceil bool - :param layer: int - :return: - """ - resolution = np.array(resolution) - scaling = np.array(self.cv.resolution / resolution, dtype=np.int) - - x = (x / scaling[0] - self.vx_vol_bounds[0, 0]) / self.chunk_size[0] - y = (y / scaling[1] - self.vx_vol_bounds[1, 0]) / self.chunk_size[1] - z = (z / scaling[2] - self.vx_vol_bounds[2, 0]) / self.chunk_size[2] - - x /= self.fan_out ** (max(layer - 2, 0)) - y /= self.fan_out ** (max(layer - 2, 0)) - z /= self.fan_out ** (max(layer - 2, 0)) - - coords = np.array([x, y, z]) - if ceil: - coords = np.ceil(coords) - - return coords.astype(np.int) - - def get_chunk_layer(self, node_or_chunk_id: np.uint64) -> int: - """ Extract Layer from Node ID or Chunk ID - - :param node_or_chunk_id: np.uint64 - :return: int - """ - return int(int(node_or_chunk_id) >> 64 - self._n_bits_for_layer_id) - - def get_chunk_layers(self, node_or_chunk_ids: Sequence[np.uint64] - ) -> np.ndarray: - """ Extract Layers from Node IDs or Chunk IDs - - :param node_or_chunk_ids: np.ndarray - :return: np.ndarray - """ - if len(node_or_chunk_ids) == 0: - return np.array([], dtype=np.int) - - return self._get_chunk_layer_vec(node_or_chunk_ids) - - def get_chunk_coordinates(self, node_or_chunk_id: np.uint64 - ) -> np.ndarray: - """ Extract X, Y and Z coordinate from Node ID or Chunk ID - - :param node_or_chunk_id: np.uint64 - :return: Tuple(int, int, int) - """ - layer = self.get_chunk_layer(node_or_chunk_id) - bits_per_dim = self.bitmasks[layer] - - x_offset = 64 - self._n_bits_for_layer_id - 
bits_per_dim - y_offset = x_offset - bits_per_dim - z_offset = y_offset - bits_per_dim - - x = int(node_or_chunk_id) >> x_offset & 2 ** bits_per_dim - 1 - y = int(node_or_chunk_id) >> y_offset & 2 ** bits_per_dim - 1 - z = int(node_or_chunk_id) >> z_offset & 2 ** bits_per_dim - 1 - return np.array([x, y, z]) - - def get_chunk_id(self, node_id: Optional[np.uint64] = None, - layer: Optional[int] = None, - x: Optional[int] = None, - y: Optional[int] = None, - z: Optional[int] = None) -> np.uint64: - """ (1) Extract Chunk ID from Node ID - (2) Build Chunk ID from Layer, X, Y and Z components - - :param node_id: np.uint64 - :param layer: int - :param x: int - :param y: int - :param z: int - :return: np.uint64 - """ - assert node_id is not None or \ - all(v is not None for v in [layer, x, y, z]) - - if node_id is not None: - layer = self.get_chunk_layer(node_id) - bits_per_dim = self.bitmasks[layer] - - if node_id is not None: - chunk_offset = 64 - self._n_bits_for_layer_id - 3 * bits_per_dim - return np.uint64((int(node_id) >> chunk_offset) << chunk_offset) - else: - - if not(x < 2 ** bits_per_dim and - y < 2 ** bits_per_dim and - z < 2 ** bits_per_dim): - raise Exception("Chunk coordinate is out of range for" - "this graph on layer %d with %d bits/dim." - "[%d, %d, %d]; max = %d." - % (layer, bits_per_dim, x, y, z, - 2 ** bits_per_dim)) - - layer_offset = 64 - self._n_bits_for_layer_id - x_offset = layer_offset - bits_per_dim - y_offset = x_offset - bits_per_dim - z_offset = y_offset - bits_per_dim - return np.uint64(layer << layer_offset | x << x_offset | - y << y_offset | z << z_offset) - - def get_chunk_ids_from_node_ids(self, node_ids: Iterable[np.uint64] - ) -> np.ndarray: - """ Extract a list of Chunk IDs from a list of Node IDs - - :param node_ids: np.ndarray(dtype=np.uint64) - :return: np.ndarray(dtype=np.uint64) - """ - if len(node_ids) == 0: - return np.array([], dtype=np.int) - - return self._get_chunk_id_vec(node_ids) - - def get_child_chunk_ids(self, node_or_chunk_id: np.uint64) -> np.ndarray: - """ Calculates the ids of the children chunks in the next lower layer - - :param node_or_chunk_id: np.uint64 - :return: np.ndarray - """ - chunk_coords = self.get_chunk_coordinates(node_or_chunk_id) - chunk_layer = self.get_chunk_layer(node_or_chunk_id) - - if chunk_layer == 1: - return np.array([]) - elif chunk_layer == 2: - x, y, z = chunk_coords - return np.array([self.get_chunk_id(layer=chunk_layer-1, - x=x, y=y, z=z)]) - else: - chunk_ids = [] - for dcoord in itertools.product(*[range(self.fan_out)]*3): - x, y, z = chunk_coords * self.fan_out + np.array(dcoord) - child_chunk_id = self.get_chunk_id(layer=chunk_layer-1, - x=x, y=y, z=z) - chunk_ids.append(child_chunk_id) - - return np.array(chunk_ids) - - def get_parent_chunk_ids(self, node_or_chunk_id: np.uint64) -> np.ndarray: - """ Creates list of chunk parent ids - - :param node_or_chunk_id: np.uint64 - :return: np.ndarray - """ - parent_chunk_layers = range(self.get_chunk_layer(node_or_chunk_id) + 1, - self.n_layers + 1) - chunk_coord = self.get_chunk_coordinates(node_or_chunk_id) - - parent_chunk_ids = [self.get_chunk_id(node_or_chunk_id)] - for layer in parent_chunk_layers: - chunk_coord = chunk_coord // self.fan_out - parent_chunk_ids.append(self.get_chunk_id(layer=layer, - x=chunk_coord[0], - y=chunk_coord[1], - z=chunk_coord[2])) - return np.array(parent_chunk_ids, dtype=np.uint64) - - def get_parent_chunk_id_dict(self, node_or_chunk_id: np.uint64) -> dict: - """ Creates dict of chunk parent ids - - :param node_or_chunk_id: 
np.uint64 - :return: dict - """ - chunk_layer = self.get_chunk_layer(node_or_chunk_id) - return dict(zip(range(chunk_layer, self.n_layers + 1), - self.get_parent_chunk_ids(node_or_chunk_id))) - - def get_segment_id_limit(self, node_or_chunk_id: np.uint64) -> np.uint64: - """ Get maximum possible Segment ID for given Node ID or Chunk ID - - :param node_or_chunk_id: np.uint64 - :return: np.uint64 - """ - - layer = self.get_chunk_layer(node_or_chunk_id) - bits_per_dim = self.bitmasks[layer] - chunk_offset = 64 - self._n_bits_for_layer_id - 3 * bits_per_dim - return np.uint64(2 ** chunk_offset - 1) - - def get_segment_id(self, node_id: np.uint64) -> np.uint64: - """ Extract Segment ID from Node ID - - :param node_id: np.uint64 - :return: np.uint64 - """ - - return node_id & self.get_segment_id_limit(node_id) - - def get_node_id(self, segment_id: np.uint64, - chunk_id: Optional[np.uint64] = None, - layer: Optional[int] = None, - x: Optional[int] = None, - y: Optional[int] = None, - z: Optional[int] = None) -> np.uint64: - """ (1) Build Node ID from Segment ID and Chunk ID - (2) Build Node ID from Segment ID, Layer, X, Y and Z components - - :param segment_id: np.uint64 - :param chunk_id: np.uint64 - :param layer: int - :param x: int - :param y: int - :param z: int - :return: np.uint64 - """ - - if chunk_id is not None: - return chunk_id | segment_id - else: - return self.get_chunk_id(layer=layer, x=x, y=y, z=z) | segment_id - - def _get_unique_range(self, row_key, step): - column = column_keys.Concurrency.CounterID - - # Incrementer row keys start with an "i" followed by the chunk id - append_row = self.table.row(row_key, append=True) - append_row.increment_cell_value(column.family_id, column.key, step) - - # This increments the row entry and returns the value AFTER incrementing - latest_row = append_row.commit() - max_segment_id = column.deserialize(latest_row[column.family_id][column.key][0][0]) - - min_segment_id = max_segment_id + np.uint64(1) - step - return min_segment_id, max_segment_id - - def get_unique_segment_id_root_row(self, step: int = 1, - counter_id: int = None) -> np.ndarray: - """ Return unique Segment ID for the Root Chunk - - atomic counter - - :param step: int - :param counter_id: np.uint64 - :return: np.uint64 - """ - if self.n_bits_root_counter == 0: - return self.get_unique_segment_id_range(self.root_chunk_id, - step=step) - - n_counters = np.uint64(2 ** self._n_bits_root_counter) - - if counter_id is None: - counter_id = np.uint64(np.random.randint(0, n_counters)) - else: - counter_id = np.uint64(counter_id % n_counters) - - row_key = serializers.serialize_key( - f"i{serializers.pad_node_id(self.root_chunk_id)}_{counter_id}") - - min_segment_id, max_segment_id = self._get_unique_range(row_key=row_key, - step=step) - - segment_id_range = np.arange(min_segment_id * n_counters + counter_id, - max_segment_id * n_counters + - np.uint64(1) + counter_id, n_counters, - dtype=basetypes.SEGMENT_ID) - - return segment_id_range - - def get_unique_segment_id_range(self, chunk_id: np.uint64, step: int = 1 - ) -> np.ndarray: - """ Return unique Segment ID for given Chunk ID - - atomic counter - - :param chunk_id: np.uint64 - :param step: int - :return: np.uint64 - """ - if self.n_layers == self.get_chunk_layer(chunk_id) and \ - self.n_bits_root_counter > 0: - return self.get_unique_segment_id_root_row(step=step) - - row_key = serializers.serialize_key( - "i%s" % serializers.pad_node_id(chunk_id)) - min_segment_id, max_segment_id = self._get_unique_range(row_key=row_key, - step=step) - 
segment_id_range = np.arange(min_segment_id, - max_segment_id + np.uint64(1), - dtype=basetypes.SEGMENT_ID) - return segment_id_range - - def get_unique_segment_id(self, chunk_id: np.uint64) -> np.uint64: - """ Return unique Segment ID for given Chunk ID - - atomic counter - - :param chunk_id: np.uint64 - :param step: int - :return: np.uint64 - """ - - return self.get_unique_segment_id_range(chunk_id=chunk_id, step=1)[0] - - def get_unique_node_id_range(self, chunk_id: np.uint64, step: int = 1 - ) -> np.ndarray: - """ Return unique Node ID range for given Chunk ID - - atomic counter - - :param chunk_id: np.uint64 - :param step: int - :return: np.uint64 - """ - - segment_ids = self.get_unique_segment_id_range(chunk_id=chunk_id, - step=step) - - node_ids = np.array([self.get_node_id(segment_id, chunk_id) - for segment_id in segment_ids], dtype=np.uint64) - return node_ids - - def get_unique_node_id(self, chunk_id: np.uint64) -> np.uint64: - """ Return unique Node ID for given Chunk ID - - atomic counter - - :param chunk_id: np.uint64 - :return: np.uint64 - """ - - return self.get_unique_node_id_range(chunk_id=chunk_id, step=1)[0] - - def get_max_seg_id_root_chunk(self) -> np.uint64: - """ Gets maximal root id based on the atomic counter - - This is an approximation. It is not guaranteed that all ids smaller or - equal to this id exists. However, it is guaranteed that no larger id - exist at the time this function is executed. - - :return: uint64 - """ - if self.n_bits_root_counter == 0: - return self.get_max_seg_id(self.root_chunk_id) - - n_counters = np.uint64(2 ** self.n_bits_root_counter) - max_value = 0 - for counter_id in range(n_counters): - row_key = serializers.serialize_key( - f"i{serializers.pad_node_id(self.root_chunk_id)}_{counter_id}") - - row = self.read_byte_row(row_key, - columns=column_keys.Concurrency.CounterID) - - counter = basetypes.SEGMENT_ID.type(row[0].value if row else 0) * \ - n_counters - if counter > max_value: - max_value = counter - - return max_value - - def get_max_seg_id(self, chunk_id: np.uint64) -> np.uint64: - """ Gets maximal seg id in a chunk based on the atomic counter - - This is an approximation. It is not guaranteed that all ids smaller or - equal to this id exists. However, it is guaranteed that no larger id - exist at the time this function is executed. - - - :return: uint64 - """ - if self.n_layers == self.get_chunk_layer(chunk_id) and \ - self.n_bits_root_counter > 0: - return self.get_max_seg_id_root_chunk() - - # Incrementer row keys start with an "i" - row_key = serializers.serialize_key( - "i%s" % serializers.pad_node_id(chunk_id)) - row = self.read_byte_row(row_key, - columns=column_keys.Concurrency.CounterID) - - # Read incrementer value (default to 0) and interpret is as Segment ID - return basetypes.SEGMENT_ID.type(row[0].value if row else 0) - - def get_max_node_id(self, chunk_id: np.uint64) -> np.uint64: - """ Gets maximal node id in a chunk based on the atomic counter - - This is an approximation. It is not guaranteed that all ids smaller or - equal to this id exists. However, it is guaranteed that no larger id - exist at the time this function is executed. - - - :return: uint64 - """ - - max_seg_id = self.get_max_seg_id(chunk_id) - return self.get_node_id(segment_id=max_seg_id, chunk_id=chunk_id) - - def get_unique_operation_id(self) -> np.uint64: - """ Finds a unique operation id - - atomic counter - - Operations essentially live in layer 0. 
Even if segmentation ids might - live in layer 0 one day, they would not collide with the operation ids - because we write information belonging to operations in a separate - family id. - - :return: str - """ - column = column_keys.Concurrency.CounterID - - append_row = self.table.row(row_keys.OperationID, append=True) - append_row.increment_cell_value(column.family_id, column.key, 1) - - # This increments the row entry and returns the value AFTER incrementing - latest_row = append_row.commit() - operation_id_b = latest_row[column.family_id][column.key][0][0] - operation_id = column.deserialize(operation_id_b) - - return np.uint64(operation_id) - - def get_max_operation_id(self) -> np.int64: - """ Gets maximal operation id based on the atomic counter - - This is an approximation. It is not guaranteed that all ids smaller or - equal to this id exists. However, it is guaranteed that no larger id - exist at the time this function is executed. - - - :return: int64 - """ - column = column_keys.Concurrency.CounterID - row = self.read_byte_row(row_keys.OperationID, columns=column) - - return row[0].value if row else column.basetype(0) - - def get_cross_chunk_edges_layer(self, cross_edges): - """ Computes the layer in which a cross chunk edge becomes relevant. - - I.e. if a cross chunk edge links two nodes in layer 4 this function - returns 3. - - :param cross_edges: n x 2 array - edges between atomic (level 1) node ids - :return: array of length n - """ - if len(cross_edges) == 0: - return np.array([], dtype=np.int) - - cross_chunk_edge_layers = np.ones(len(cross_edges), dtype=np.int) - - cross_edge_coordinates = [] - for cross_edge in cross_edges: - cross_edge_coordinates.append( - [self.get_chunk_coordinates(cross_edge[0]), - self.get_chunk_coordinates(cross_edge[1])]) - - cross_edge_coordinates = np.array(cross_edge_coordinates, dtype=np.int) - - for layer in range(2, self.n_layers): - edge_diff = np.sum(np.abs(cross_edge_coordinates[:, 0] - - cross_edge_coordinates[:, 1]), axis=1) - cross_chunk_edge_layers[edge_diff > 0] += 1 - cross_edge_coordinates = cross_edge_coordinates // self.fan_out - - return cross_chunk_edge_layers - - def get_cross_chunk_edge_dict(self, cross_edges): - """ Generates a cross chunk edge dict for a list of cross chunk edges - - :param cross_edges: n x 2 array - :return: dict - """ - cce_layers = self.get_cross_chunk_edges_layer(cross_edges) - u_cce_layers = np.unique(cce_layers) - cross_edge_dict = {} - - for l in range(2, self.n_layers): - cross_edge_dict[l] = column_keys.Connectivity.CrossChunkEdge.deserialize(b'') - - val_dict = {} - for cc_layer in u_cce_layers: - layer_cross_edges = cross_edges[cce_layers == cc_layer] - - if len(layer_cross_edges) > 0: - val_dict[column_keys.Connectivity.CrossChunkEdge[cc_layer]] = \ - layer_cross_edges - cross_edge_dict[cc_layer] = layer_cross_edges - return cross_edge_dict - - def read_byte_rows( - self, - start_key: Optional[bytes] = None, - end_key: Optional[bytes] = None, - end_key_inclusive: bool = False, - row_keys: Optional[Iterable[bytes]] = None, - columns: Optional[Union[Iterable[column_keys._Column], column_keys._Column]] = None, - start_time: Optional[datetime.datetime] = None, - end_time: Optional[datetime.datetime] = None, - end_time_inclusive: bool = False) -> Dict[bytes, Union[ - Dict[column_keys._Column, List[bigtable.row_data.Cell]], - List[bigtable.row_data.Cell] - ]]: - """Main function for reading a row range or non-contiguous row sets from Bigtable using - `bytes` keys. 
- - Keyword Arguments: - start_key {Optional[bytes]} -- The first row to be read, ignored if `row_keys` is set. - If None, no lower boundary is used. (default: {None}) - end_key {Optional[bytes]} -- The end of the row range, ignored if `row_keys` is set. - If None, no upper boundary is used. (default: {None}) - end_key_inclusive {bool} -- Whether or not `end_key` itself should be included in the - request, ignored if `row_keys` is set or `end_key` is None. (default: {False}) - row_keys {Optional[Iterable[bytes]]} -- An `Iterable` containing possibly - non-contiguous row keys. Takes precedence over `start_key` and `end_key`. - (default: {None}) - columns {Optional[Union[Iterable[column_keys._Column], column_keys._Column]]} -- - Optional filtering by columns to speed up the query. If `columns` is a single - column (not iterable), the column key will be omitted from the result. - (default: {None}) - start_time {Optional[datetime.datetime]} -- Ignore cells with timestamp before - `start_time`. If None, no lower bound. (default: {None}) - end_time {Optional[datetime.datetime]} -- Ignore cells with timestamp after `end_time`. - If None, no upper bound. (default: {None}) - end_time_inclusive {bool} -- Whether or not `end_time` itself should be included in the - request, ignored if `end_time` is None. (default: {False}) - - Returns: - Dict[bytes, Union[Dict[column_keys._Column, List[bigtable.row_data.Cell]], - List[bigtable.row_data.Cell]]] -- - Returns a dictionary of `byte` rows as keys. Their value will be a mapping of - columns to a List of cells (one cell per timestamp). Each cell has a `value` - property, which returns the deserialized field, and a `timestamp` property, which - returns the timestamp as `datetime.datetime` object. - If only a single `column_keys._Column` was requested, the List of cells will be - attached to the row dictionary directly (skipping the column dictionary). - """ - - # Create filters: Column and Time - filter_ = get_time_range_and_column_filter( - columns=columns, - start_time=start_time, - end_time=end_time, - end_inclusive=end_time_inclusive) - - # Create filters: Rows - row_set = RowSet() - - if row_keys is not None: - for row_key in row_keys: - row_set.add_row_key(row_key) - elif start_key is not None and end_key is not None: - row_set.add_row_range_from_keys( - start_key=start_key, - start_inclusive=True, - end_key=end_key, - end_inclusive=end_key_inclusive) - else: - raise cg_exceptions.PreconditionError("Need to either provide a valid set of rows, or" - " both, a start row and an end row.") - - # Bigtable read with retries - rows = self._execute_read(row_set=row_set, row_filter=filter_) - - # Deserialize cells - for row_key, column_dict in rows.items(): - for column, cell_entries in column_dict.items(): - for cell_entry in cell_entries: - cell_entry.value = column.deserialize(cell_entry.value) - # If no column array was requested, reattach single column's values directly to the row - if isinstance(columns, column_keys._Column): - rows[row_key] = cell_entries - - return rows - - def read_byte_row( - self, - row_key: bytes, - columns: Optional[Union[Iterable[column_keys._Column], column_keys._Column]] = None, - start_time: Optional[datetime.datetime] = None, - end_time: Optional[datetime.datetime] = None, - end_time_inclusive: bool = False) -> \ - Union[Dict[column_keys._Column, List[bigtable.row_data.Cell]], - List[bigtable.row_data.Cell]]: - """Convenience function for reading a single row from Bigtable using its `bytes` keys. 
- - Arguments: - row_key {bytes} -- The row to be read. - - Keyword Arguments: - columns {Optional[Union[Iterable[column_keys._Column], column_keys._Column]]} -- - Optional filtering by columns to speed up the query. If `columns` is a single - column (not iterable), the column key will be omitted from the result. - (default: {None}) - start_time {Optional[datetime.datetime]} -- Ignore cells with timestamp before - `start_time`. If None, no lower bound. (default: {None}) - end_time {Optional[datetime.datetime]} -- Ignore cells with timestamp after `end_time`. - If None, no upper bound. (default: {None}) - end_time_inclusive {bool} -- Whether or not `end_time` itself should be included in the - request, ignored if `end_time` is None. (default: {False}) - - Returns: - Union[Dict[column_keys._Column, List[bigtable.row_data.Cell]], - List[bigtable.row_data.Cell]] -- - Returns a mapping of columns to a List of cells (one cell per timestamp). Each cell - has a `value` property, which returns the deserialized field, and a `timestamp` - property, which returns the timestamp as `datetime.datetime` object. - If only a single `column_keys._Column` was requested, the List of cells is returned - directly. - """ - row = self.read_byte_rows(row_keys=[row_key], columns=columns, start_time=start_time, - end_time=end_time, end_time_inclusive=end_time_inclusive) - - if isinstance(columns, column_keys._Column): - return row.get(row_key, []) - else: - return row.get(row_key, {}) - - def read_node_id_rows( - self, - start_id: Optional[np.uint64] = None, - end_id: Optional[np.uint64] = None, - end_id_inclusive: bool = False, - node_ids: Optional[Iterable[np.uint64]] = None, - columns: Optional[Union[Iterable[column_keys._Column], column_keys._Column]] = None, - start_time: Optional[datetime.datetime] = None, - end_time: Optional[datetime.datetime] = None, - end_time_inclusive: bool = False) -> Dict[np.uint64, Union[ - Dict[column_keys._Column, List[bigtable.row_data.Cell]], - List[bigtable.row_data.Cell] - ]]: - """Convenience function for reading a row range or non-contiguous row sets from Bigtable - representing NodeIDs. - - Keyword Arguments: - start_id {Optional[np.uint64]} -- The first row to be read, ignored if `node_ids` is - set. If None, no lower boundary is used. (default: {None}) - end_id {Optional[np.uint64]} -- The end of the row range, ignored if `node_ids` is set. - If None, no upper boundary is used. (default: {None}) - end_id_inclusive {bool} -- Whether or not `end_id` itself should be included in the - request, ignored if `node_ids` is set or `end_id` is None. (default: {False}) - node_ids {Optional[Iterable[np.uint64]]} -- An `Iterable` containing possibly - non-contiguous row keys. Takes precedence over `start_id` and `end_id`. - (default: {None}) - columns {Optional[Union[Iterable[column_keys._Column], column_keys._Column]]} -- - Optional filtering by columns to speed up the query. If `columns` is a single - column (not iterable), the column key will be omitted from the result. - (default: {None}) - start_time {Optional[datetime.datetime]} -- Ignore cells with timestamp before - `start_time`. If None, no lower bound. (default: {None}) - end_time {Optional[datetime.datetime]} -- Ignore cells with timestamp after `end_time`. - If None, no upper bound. (default: {None}) - end_time_inclusive {bool} -- Whether or not `end_time` itself should be included in the - request, ignored if `end_time` is None. 
(default: {False}) - - Returns: - Dict[np.uint64, Union[Dict[column_keys._Column, List[bigtable.row_data.Cell]], - List[bigtable.row_data.Cell]]] -- - Returns a dictionary of NodeID rows as keys. Their value will be a mapping of - columns to a List of cells (one cell per timestamp). Each cell has a `value` - property, which returns the deserialized field, and a `timestamp` property, which - returns the timestamp as `datetime.datetime` object. - If only a single `column_keys._Column` was requested, the List of cells will be - attached to the row dictionary directly (skipping the column dictionary). - """ - to_bytes = serializers.serialize_uint64 - from_bytes = serializers.deserialize_uint64 - - # Read rows (convert Node IDs to row_keys) - rows = self.read_byte_rows( - start_key=to_bytes(start_id) if start_id is not None else None, - end_key=to_bytes(end_id) if end_id is not None else None, - end_key_inclusive=end_id_inclusive, - row_keys=(to_bytes(node_id) for node_id in node_ids) if node_ids is not None else None, - columns=columns, - start_time=start_time, - end_time=end_time, - end_time_inclusive=end_time_inclusive) - - # Convert row_keys back to Node IDs - return {from_bytes(row_key): data for (row_key, data) in rows.items()} - - def read_node_id_row( - self, - node_id: np.uint64, - columns: Optional[Union[Iterable[column_keys._Column], column_keys._Column]] = None, - start_time: Optional[datetime.datetime] = None, - end_time: Optional[datetime.datetime] = None, - end_time_inclusive: bool = False) -> \ - Union[Dict[column_keys._Column, List[bigtable.row_data.Cell]], - List[bigtable.row_data.Cell]]: - """Convenience function for reading a single row from Bigtable, representing a NodeID. - - Arguments: - node_id {np.uint64} -- the NodeID of the row to be read. - - Keyword Arguments: - columns {Optional[Union[Iterable[column_keys._Column], column_keys._Column]]} -- - Optional filtering by columns to speed up the query. If `columns` is a single - column (not iterable), the column key will be omitted from the result. - (default: {None}) - start_time {Optional[datetime.datetime]} -- Ignore cells with timestamp before - `start_time`. If None, no lower bound. (default: {None}) - end_time {Optional[datetime.datetime]} -- Ignore cells with timestamp after `end_time`. - If None, no upper bound. (default: {None}) - end_time_inclusive {bool} -- Whether or not `end_time` itself should be included in the - request, ignored if `end_time` is None. (default: {False}) - - Returns: - Union[Dict[column_keys._Column, List[bigtable.row_data.Cell]], - List[bigtable.row_data.Cell]] -- - Returns a mapping of columns to a List of cells (one cell per timestamp). Each cell - has a `value` property, which returns the deserialized field, and a `timestamp` - property, which returns the timestamp as `datetime.datetime` object. - If only a single `column_keys._Column` was requested, the List of cells is returned - directly. 
- """ - return self.read_byte_row(row_key=serializers.serialize_uint64(node_id), columns=columns, - start_time=start_time, end_time=end_time, - end_time_inclusive=end_time_inclusive) - - def read_cross_chunk_edges(self, node_id: np.uint64, start_layer: int = 2, - end_layer: int = None) -> Dict: - """ Reads the cross chunk edge entry from the table for a given node id - and formats it as cross edge dict - - :param node_id: - :param start_layer: - :param end_layer: - :return: - """ - if end_layer is None: - end_layer = self.n_layers - - if start_layer < 2 or start_layer == self.n_layers: - return {} - - start_layer = np.max([self.get_chunk_layer(node_id), start_layer]) - - assert end_layer > start_layer and end_layer <= self.n_layers - - columns = [column_keys.Connectivity.CrossChunkEdge[l] - for l in range(start_layer, end_layer)] - row_dict = self.read_node_id_row(node_id, columns=columns) - - cross_edge_dict = {} - for l in range(start_layer, end_layer): - col = column_keys.Connectivity.CrossChunkEdge[l] - if col in row_dict: - cross_edge_dict[l] = row_dict[col][0].value - else: - cross_edge_dict[l] = col.deserialize(b'') - - return cross_edge_dict - - def mutate_row(self, row_key: bytes, - val_dict: Dict[column_keys._Column, Any], - time_stamp: Optional[datetime.datetime] = None, - isbytes: bool = False - ) -> bigtable.row.Row: - """ Mutates a single row - - :param row_key: serialized bigtable row key - :param val_dict: Dict[column_keys._TypedColumn: bytes] - :param time_stamp: None or datetime - :return: list - """ - row = self.table.row(row_key) - - for column, value in val_dict.items(): - if not isbytes: - value = column.serialize(value) - - row.set_cell(column_family_id=column.family_id, - column=column.key, - value=value, - timestamp=time_stamp) - return row - - def bulk_write(self, rows: Iterable[bigtable.row.DirectRow], - root_ids: Optional[Union[np.uint64, - Iterable[np.uint64]]] = None, - operation_id: Optional[np.uint64] = None, - slow_retry: bool = True, - block_size: int = 2000): - """ Writes a list of mutated rows in bulk - - WARNING: If contains the same row (same row_key) and column - key two times only the last one is effectively written to the BigTable - (even when the mutations were applied to different columns) - --> no versioning! - - :param rows: list - list of mutated rows - :param root_ids: list if uint64 - :param operation_id: uint64 or None - operation_id (or other unique id) that *was* used to lock the root - the bulk write is only executed if the root is still locked with - the same id. 
- :param slow_retry: bool - :param block_size: int - """ - if slow_retry: - initial = 5 - else: - initial = 1 - - retry_policy = Retry( - predicate=if_exception_type((Aborted, - DeadlineExceeded, - ServiceUnavailable)), - initial=initial, - maximum=15.0, - multiplier=2.0, - deadline=LOCK_EXPIRED_TIME_DELTA.seconds) - - if root_ids is not None and operation_id is not None: - if isinstance(root_ids, int): - root_ids = [root_ids] - - if not self.check_and_renew_root_locks(root_ids, operation_id): - raise cg_exceptions.LockError(f"Root lock renewal failed for operation ID {operation_id}") - - for i_row in range(0, len(rows), block_size): - status = self.table.mutate_rows(rows[i_row: i_row + block_size], - retry=retry_policy) - - if not all(status): - raise cg_exceptions.ChunkedGraphError(f"Bulk write failed for operation ID {operation_id}") - - def _execute_read_thread(self, row_set_and_filter: Tuple[RowSet, RowFilter]): - row_set, row_filter = row_set_and_filter - if not row_set.row_keys and not row_set.row_ranges: - # Check for everything falsy, because Bigtable considers even empty - # lists of row_keys as no upper/lower bound! - return {} - - range_read = self.table.read_rows(row_set=row_set, filter_=row_filter) - res = {v.row_key: partial_row_data_to_column_dict(v) - for v in range_read} - return res - - def _execute_read(self, row_set: RowSet, row_filter: RowFilter = None) \ - -> Dict[bytes, Dict[column_keys._Column, bigtable.row_data.PartialRowData]]: - """ Core function to read rows from Bigtable. Uses standard Bigtable retry logic - :param row_set: BigTable RowSet - :param row_filter: BigTable RowFilter - :return: Dict[bytes, Dict[column_keys._Column, bigtable.row_data.PartialRowData]] - """ - - # FIXME: Bigtable limits the length of the serialized request to 512 KiB. We should - # calculate this properly (range_read.request.SerializeToString()), but this estimate is - # good enough for now - max_row_key_count = 20000 - n_subrequests = max(1, int(np.ceil(len(row_set.row_keys) / - max_row_key_count))) - n_threads = min(n_subrequests, 2 * mu.n_cpus) - - row_sets = [] - for i in range(n_subrequests): - r = RowSet() - r.row_keys = row_set.row_keys[i * max_row_key_count: - (i + 1) * max_row_key_count] - row_sets.append(r) - - # Don't forget the original RowSet's row_ranges - row_sets[0].row_ranges = row_set.row_ranges - - responses = mu.multithread_func(self._execute_read_thread, - params=((r, row_filter) - for r in row_sets), - debug=n_threads == 1, - n_threads=n_threads) - - combined_response = {} - for resp in responses: - combined_response.update(resp) - - return combined_response - - def range_read_chunk( - self, - layer: Optional[int] = None, - x: Optional[int] = None, - y: Optional[int] = None, - z: Optional[int] = None, - chunk_id: Optional[np.uint64] = None, - columns: Optional[Union[Iterable[column_keys._Column], column_keys._Column]] = None, - time_stamp: Optional[datetime.datetime] = None) -> Dict[np.uint64, Union[ - Dict[column_keys._Column, List[bigtable.row_data.Cell]], - List[bigtable.row_data.Cell] - ]]: - """Convenience function for reading all NodeID rows of a single chunk from Bigtable. - Chunk can either be specified by its (layer, x, y, and z coordinate), or by the chunk ID. - - Keyword Arguments: - layer {Optional[int]} -- The layer of the chunk within the graph (default: {None}) - x {Optional[int]} -- The xth chunk in x dimension within the graph, within `layer`. 
- (default: {None}) - y {Optional[int]} -- The yth chunk in y dimension within the graph, within `layer`. - (default: {None}) - z {Optional[int]} -- The zth chunk in z dimension within the graph, within `layer`. - (default: {None}) - chunk_id {Optional[np.uint64]} -- Alternative way to specify the chunk, if the Chunk ID - is already known. (default: {None}) - columns {Optional[Union[Iterable[column_keys._Column], column_keys._Column]]} -- - Optional filtering by columns to speed up the query. If `columns` is a single - column (not iterable), the column key will be omitted from the result. - (default: {None}) - time_stamp {Optional[datetime.datetime]} -- Ignore cells with timestamp after `end_time`. - If None, no upper bound. (default: {None}) - - Returns: - Dict[np.uint64, Union[Dict[column_keys._Column, List[bigtable.row_data.Cell]], - List[bigtable.row_data.Cell]]] -- - Returns a dictionary of NodeID rows as keys. Their value will be a mapping of - columns to a List of cells (one cell per timestamp). Each cell has a `value` - property, which returns the deserialized field, and a `timestamp` property, which - returns the timestamp as `datetime.datetime` object. - If only a single `column_keys._Column` was requested, the List of cells will be - attached to the row dictionary directly (skipping the column dictionary). - """ - if chunk_id is not None: - x, y, z = self.get_chunk_coordinates(chunk_id) - layer = self.get_chunk_layer(chunk_id) - elif layer is not None and x is not None and y is not None and z is not None: - chunk_id = self.get_chunk_id(layer=layer, x=x, y=y, z=z) - else: - raise Exception("Either chunk_id or layer and coordinates have to be defined") - - if layer == 1: - max_segment_id = self.get_segment_id_limit(chunk_id) - else: - max_segment_id = self.get_max_seg_id(chunk_id=chunk_id) - - # Define BigTable keys - start_id = self.get_node_id(np.uint64(0), chunk_id=chunk_id) - end_id = self.get_node_id(max_segment_id, chunk_id=chunk_id) - - try: - rr = self.read_node_id_rows( - start_id=start_id, - end_id=end_id, - end_id_inclusive=True, - columns=columns, - end_time=time_stamp, - end_time_inclusive=True) - except Exception as err: - raise Exception("Unable to consume chunk read: " - "[%d, %d, %d], l = %d: %s" % - (x, y, z, layer, err)) - return rr - - def range_read_layer(self, layer_id: int): - """ Reads all ids within a layer - - This can take a while depending on the size of the graph - - :param layer_id: int - :return: list of rows - """ - raise NotImplementedError() - - def test_if_nodes_are_in_same_chunk(self, node_ids: Sequence[np.uint64] - ) -> bool: - """ Test whether two nodes are in the same chunk - - :param node_ids: list of two ints - :return: bool - """ - assert len(node_ids) == 2 - return self.get_chunk_id(node_id=node_ids[0]) == \ - self.get_chunk_id(node_id=node_ids[1]) - - def get_chunk_id_from_coord(self, layer: int, - x: int, y: int, z: int) -> np.uint64: - """ Return ChunkID for given chunked graph layer and voxel coordinates. 
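To make the two chunk-addressing options of `range_read_chunk` concrete, a small usage sketch; the layer, coordinates, `cg` handle, and import path are placeholders, not values from this module.

```python
from pychunkedgraph.backend.utils import column_keys  # import path assumed

# `cg` is assumed to be a legacy ChunkedGraph instance.
# A chunk can be addressed by (layer, x, y, z) ...
rows = cg.range_read_chunk(layer=2, x=3, y=1, z=0,
                           columns=column_keys.Hierarchy.Child)

# ... or equivalently by a precomputed chunk ID.
chunk_id = cg.get_chunk_id(layer=2, x=3, y=1, z=0)
rows_by_id = cg.range_read_chunk(chunk_id=chunk_id,
                                 columns=column_keys.Hierarchy.Child)

# Single column requested: each value is the list of cells for that column.
for node_id, cells in rows.items():
    children = cells[0].value  # most recent Child cell, an array of child NodeIDs
```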
- - :param layer: int -- ChunkedGraph layer - :param x: int -- X coordinate in voxel - :param y: int -- Y coordinate in voxel - :param z: int -- Z coordinate in voxel - :return: np.uint64 -- ChunkID - """ - base_chunk_span = int(self.fan_out) ** max(0, layer - 2) - - return self.get_chunk_id( - layer=layer, - x=x // (int(self.chunk_size[0]) * base_chunk_span), - y=y // (int(self.chunk_size[1]) * base_chunk_span), - z=z // (int(self.chunk_size[2]) * base_chunk_span)) - - def get_atomic_id_from_coord(self, x: int, y: int, z: int, - parent_id: np.uint64, n_tries: int=5 - ) -> np.uint64: - """ Determines atomic id given a coordinate - - :param x: int - :param y: int - :param z: int - :param parent_id: np.uint64 - :param n_tries: int - :return: np.uint64 or None - """ - if self.get_chunk_layer(parent_id) == 1: - return parent_id - - - x /= 2**self.cv_mip - y /= 2**self.cv_mip - - x = int(x) - y = int(y) - z = int(z) - - checked = [] - atomic_id = None - root_id = self.get_root(parent_id) - - for i_try in range(n_tries): - - # Define block size -- increase by one each try - x_l = x - (i_try - 1)**2 - y_l = y - (i_try - 1)**2 - z_l = z - (i_try - 1)**2 - - x_h = x + 1 + (i_try - 1)**2 - y_h = y + 1 + (i_try - 1)**2 - z_h = z + 1 + (i_try - 1)**2 - - if x_l < 0: - x_l = 0 - - if y_l < 0: - y_l = 0 - - if z_l < 0: - z_l = 0 - - # Get atomic ids from cloudvolume - atomic_id_block = self.cv[x_l: x_h, y_l: y_h, z_l: z_h] - atomic_ids, atomic_id_count = np.unique(atomic_id_block, - return_counts=True) - - # sort by frequency and discard those ids that have been checked - # previously - sorted_atomic_ids = atomic_ids[np.argsort(atomic_id_count)] - sorted_atomic_ids = sorted_atomic_ids[~np.in1d(sorted_atomic_ids, - checked)] - - # For each candidate id check whether its root id corresponds to the - # given root id - for candidate_atomic_id in sorted_atomic_ids: - ass_root_id = self.get_root(candidate_atomic_id) - - if ass_root_id == root_id: - # atomic_id is not None will be our indicator that the - # search was successful - - atomic_id = candidate_atomic_id - break - else: - checked.append(candidate_atomic_id) - - if atomic_id is not None: - break - - # Returns None if unsuccessful - return atomic_id - - def read_log_row( - self, operation_id: np.uint64 - ) -> Dict[column_keys._Column, Union[np.ndarray, np.number]]: - """ Retrieves log record from Bigtable for a given operation ID - - :param operation_id: np.uint64 - :return: Dict[column_keys._Column, Union[np.ndarray, np.number]] - """ - columns = [ - column_keys.OperationLogs.UndoOperationID, - column_keys.OperationLogs.RedoOperationID, - column_keys.OperationLogs.UserID, - column_keys.OperationLogs.RootID, - column_keys.OperationLogs.SinkID, - column_keys.OperationLogs.SourceID, - column_keys.OperationLogs.SourceCoordinate, - column_keys.OperationLogs.SinkCoordinate, - column_keys.OperationLogs.AddedEdge, - column_keys.OperationLogs.Affinity, - column_keys.OperationLogs.RemovedEdge, - column_keys.OperationLogs.BoundingBoxOffset, - ] - log_record = self.read_node_id_row(operation_id, columns=columns) - log_record.update((column, v[0].value) for column, v in log_record.items()) - return log_record - - def read_first_log_row(self): - """ Returns first log row - - :return: None or dict - """ - - for operation_id in range(1, 100): - log_row = self.read_log_row(np.uint64(operation_id)) - - if len(log_row) > 0: - return log_row - - return None - - def add_atomic_edges_in_chunks(self, edge_id_dict: dict, - edge_aff_dict: dict, edge_area_dict: dict, - 
isolated_node_ids: Sequence[np.uint64], - verbose: bool = True, - time_stamp: Optional[datetime.datetime] = None): - """ Creates atomic nodes in first abstraction layer for a SINGLE chunk - and all abstract nodes in the second for the same chunk - - Alle edges (edge_ids) need to be from one chunk and no nodes should - exist for this chunk prior to calling this function. All cross edges - (cross_edge_ids) have to point out the chunk (first entry is the id - within the chunk) - - :param edge_id_dict: dict - :param edge_aff_dict: dict - :param edge_area_dict: dict - :param isolated_node_ids: list of uint64s - ids of nodes that have no edge in the chunked graph - :param verbose: bool - :param time_stamp: datetime - """ - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - # Comply to resolution of BigTables TimeRange - time_stamp = get_google_compatible_time_stamp(time_stamp, - round_up=False) - - edge_id_keys = ["in_connected", "in_disconnected", "cross", - "between_connected", "between_disconnected"] - edge_aff_keys = ["in_connected", "in_disconnected", "between_connected", - "between_disconnected"] - - # Check if keys exist and include an empty array if not - n_edge_ids = 0 - chunk_id = None - for edge_id_key in edge_id_keys: - if not edge_id_key in edge_id_dict: - empty_edges = np.array([], dtype=np.uint64).reshape(0, 2) - edge_id_dict[edge_id_key] = empty_edges - else: - n_edge_ids += len(edge_id_dict[edge_id_key]) - - if len(edge_id_dict[edge_id_key]) > 0: - node_id = edge_id_dict[edge_id_key][0, 0] - chunk_id = self.get_chunk_id(node_id) - - for edge_aff_key in edge_aff_keys: - if not edge_aff_key in edge_aff_dict: - edge_aff_dict[edge_aff_key] = np.array([], dtype=np.float32) - - time_start = time.time() - - # Catch trivial case - if n_edge_ids == 0 and len(isolated_node_ids) == 0: - return 0 - - # Make parent id creation easier - if chunk_id is None: - chunk_id = self.get_chunk_id(isolated_node_ids[0]) - - chunk_id_c = self.get_chunk_coordinates(chunk_id) - parent_chunk_id = self.get_chunk_id(layer=2, x=chunk_id_c[0], - y=chunk_id_c[1], z=chunk_id_c[2]) - - # Get connected component within the chunk - chunk_node_ids = np.concatenate([ - isolated_node_ids.astype(np.uint64), - np.unique(edge_id_dict["in_connected"]), - np.unique(edge_id_dict["in_disconnected"]), - np.unique(edge_id_dict["cross"][:, 0]), - np.unique(edge_id_dict["between_connected"][:, 0]), - np.unique(edge_id_dict["between_disconnected"][:, 0])]) - - chunk_node_ids = np.unique(chunk_node_ids) - - node_chunk_ids = np.array([self.get_chunk_id(c) - for c in chunk_node_ids], - dtype=np.uint64) - - u_node_chunk_ids, c_node_chunk_ids = np.unique(node_chunk_ids, - return_counts=True) - if len(u_node_chunk_ids) > 1: - raise Exception("%d: %d chunk ids found in node id list. " - "Some edges might be in the wrong order. 
" - "Number of occurences:" % - (chunk_id, len(u_node_chunk_ids)), c_node_chunk_ids) - - add_edge_ids = np.vstack([chunk_node_ids, chunk_node_ids]).T - edge_ids = np.concatenate([edge_id_dict["in_connected"].copy(), - add_edge_ids]) - - graph, _, _, unique_graph_ids = flatgraph_utils.build_gt_graph( - edge_ids, make_directed=True) - - ccs = flatgraph_utils.connected_components(graph) - - if verbose: - self.logger.debug("CC in chunk: %.3fs" % (time.time() - time_start)) - - # Add rows for nodes that are in this chunk - # a connected component at a time - node_c = 0 # Just a counter for the log / speed measurement - - n_ccs = len(ccs) - - parent_ids = self.get_unique_node_id_range(parent_chunk_id, step=n_ccs) - time_start = time.time() - - time_dict = collections.defaultdict(list) - - time_start_1 = time.time() - sparse_indices = {} - remapping = {} - for k in edge_id_dict.keys(): - # Circumvent datatype issues - - u_ids, inv_ids = np.unique(edge_id_dict[k], return_inverse=True) - mapped_ids = np.arange(len(u_ids), dtype=np.int32) - remapped_arr = mapped_ids[inv_ids].reshape(edge_id_dict[k].shape) - - sparse_indices[k] = compute_indices_pandas(remapped_arr) - remapping[k] = dict(zip(u_ids, mapped_ids)) - - time_dict["sparse_indices"].append(time.time() - time_start_1) - - rows = [] - - for i_cc, cc in enumerate(ccs): - node_ids = unique_graph_ids[cc] - - u_chunk_ids = np.unique([self.get_chunk_id(n) for n in node_ids]) - - if len(u_chunk_ids) > 1: - self.logger.error(f"Found multiple chunk ids: {u_chunk_ids}") - raise Exception() - - # Create parent id - parent_id = parent_ids[i_cc] - - parent_cross_edges = np.array([], dtype=np.uint64).reshape(0, 2) - - # Add rows for nodes that are in this chunk - for i_node_id, node_id in enumerate(node_ids): - # Extract edges relevant to this node - - # in chunk + connected - time_start_2 = time.time() - if node_id in remapping["in_connected"]: - row_ids, column_ids = sparse_indices["in_connected"][remapping["in_connected"][node_id]] - - inv_column_ids = (column_ids + 1) % 2 - - connected_ids = edge_id_dict["in_connected"][row_ids, inv_column_ids] - connected_affs = edge_aff_dict["in_connected"][row_ids] - connected_areas = edge_area_dict["in_connected"][row_ids] - time_dict["in_connected"].append(time.time() - time_start_2) - time_start_2 = time.time() - else: - connected_ids = np.array([], dtype=np.uint64) - connected_affs = np.array([], dtype=np.float32) - connected_areas = np.array([], dtype=np.uint64) - - # in chunk + disconnected - if node_id in remapping["in_disconnected"]: - row_ids, column_ids = sparse_indices["in_disconnected"][remapping["in_disconnected"][node_id]] - inv_column_ids = (column_ids + 1) % 2 - - disconnected_ids = edge_id_dict["in_disconnected"][row_ids, inv_column_ids] - disconnected_affs = edge_aff_dict["in_disconnected"][row_ids] - disconnected_areas = edge_area_dict["in_disconnected"][row_ids] - time_dict["in_disconnected"].append(time.time() - time_start_2) - time_start_2 = time.time() - else: - disconnected_ids = np.array([], dtype=np.uint64) - disconnected_affs = np.array([], dtype=np.float32) - disconnected_areas = np.array([], dtype=np.uint64) - - # out chunk + connected - if node_id in remapping["between_connected"]: - row_ids, column_ids = sparse_indices["between_connected"][remapping["between_connected"][node_id]] - - row_ids = row_ids[column_ids == 0] - column_ids = column_ids[column_ids == 0] - inv_column_ids = (column_ids + 1) % 2 - time_dict["out_connected_mask"].append(time.time() - time_start_2) - time_start_2 = 
time.time() - - connected_ids = np.concatenate([connected_ids, edge_id_dict["between_connected"][row_ids, inv_column_ids]]) - connected_affs = np.concatenate([connected_affs, edge_aff_dict["between_connected"][row_ids]]) - connected_areas = np.concatenate([connected_areas, edge_area_dict["between_connected"][row_ids]]) - - parent_cross_edges = np.concatenate([parent_cross_edges, edge_id_dict["between_connected"][row_ids]]) - - time_dict["out_connected"].append(time.time() - time_start_2) - time_start_2 = time.time() - - # out chunk + disconnected - if node_id in remapping["between_disconnected"]: - row_ids, column_ids = sparse_indices["between_disconnected"][remapping["between_disconnected"][node_id]] - - row_ids = row_ids[column_ids == 0] - column_ids = column_ids[column_ids == 0] - inv_column_ids = (column_ids + 1) % 2 - time_dict["out_disconnected_mask"].append(time.time() - time_start_2) - time_start_2 = time.time() - - disconnected_ids = np.concatenate([disconnected_ids, edge_id_dict["between_disconnected"][row_ids, inv_column_ids]]) - disconnected_affs = np.concatenate([disconnected_affs, edge_aff_dict["between_disconnected"][row_ids]]) - disconnected_areas = np.concatenate([disconnected_areas, edge_area_dict["between_disconnected"][row_ids]]) - - time_dict["out_disconnected"].append(time.time() - time_start_2) - time_start_2 = time.time() - - # cross - if node_id in remapping["cross"]: - row_ids, column_ids = sparse_indices["cross"][remapping["cross"][node_id]] - - row_ids = row_ids[column_ids == 0] - column_ids = column_ids[column_ids == 0] - inv_column_ids = (column_ids + 1) % 2 - time_dict["cross_mask"].append(time.time() - time_start_2) - time_start_2 = time.time() - - connected_ids = np.concatenate([connected_ids, edge_id_dict["cross"][row_ids, inv_column_ids]]) - connected_affs = np.concatenate([connected_affs, np.full((len(row_ids)), np.inf, dtype=np.float32)]) - connected_areas = np.concatenate([connected_areas, np.ones((len(row_ids)), dtype=np.uint64)]) - - parent_cross_edges = np.concatenate([parent_cross_edges, edge_id_dict["cross"][row_ids]]) - time_dict["cross"].append(time.time() - time_start_2) - time_start_2 = time.time() - - # Create node - partners = np.concatenate([connected_ids, disconnected_ids]) - affinities = np.concatenate([connected_affs, disconnected_affs]) - areas = np.concatenate([connected_areas, disconnected_areas]) - connected = np.arange(len(connected_ids), dtype=np.int) - - val_dict = {column_keys.Connectivity.Partner: partners, - column_keys.Connectivity.Affinity: affinities, - column_keys.Connectivity.Area: areas, - column_keys.Connectivity.Connected: connected, - column_keys.Hierarchy.Parent: parent_id} - - rows.append(self.mutate_row(serializers.serialize_uint64(node_id), - val_dict, time_stamp=time_stamp)) - node_c += 1 - time_dict["creating_lv1_row"].append(time.time() - time_start_2) - - time_start_1 = time.time() - # Create parent node - rows.append(self.mutate_row(serializers.serialize_uint64(parent_id), - {column_keys.Hierarchy.Child: node_ids}, - time_stamp=time_stamp)) - - time_dict["creating_lv2_row"].append(time.time() - time_start_1) - time_start_1 = time.time() - - cce_layers = self.get_cross_chunk_edges_layer(parent_cross_edges) - u_cce_layers = np.unique(cce_layers) - - val_dict = {} - for cc_layer in u_cce_layers: - layer_cross_edges = parent_cross_edges[cce_layers == cc_layer] - - if len(layer_cross_edges) > 0: - val_dict[column_keys.Connectivity.CrossChunkEdge[cc_layer]] = \ - layer_cross_edges - - if len(val_dict) > 0: - 
rows.append(self.mutate_row(serializers.serialize_uint64(parent_id), - val_dict, time_stamp=time_stamp)) - node_c += 1 - - time_dict["adding_cross_edges"].append(time.time() - time_start_1) - - if len(rows) > 100000: - time_start_1 = time.time() - self.bulk_write(rows) - time_dict["writing"].append(time.time() - time_start_1) - - if len(rows) > 0: - time_start_1 = time.time() - self.bulk_write(rows) - time_dict["writing"].append(time.time() - time_start_1) - - if verbose: - self.logger.debug("Time creating rows: %.3fs for %d ccs with %d nodes" % - (time.time() - time_start, len(ccs), node_c)) - - for k in time_dict.keys(): - self.logger.debug("%s -- %.3fms for %d instances -- avg = %.3fms" % - (k, np.sum(time_dict[k])*1000, len(time_dict[k]), - np.mean(time_dict[k])*1000)) - - def add_layer(self, layer_id: int, - child_chunk_coords: Sequence[Sequence[int]], - time_stamp: Optional[datetime.datetime] = None, - verbose: bool = True, n_threads: int = 20) -> None: - """ Creates the abstract nodes for a given chunk in a given layer - - :param layer_id: int - :param child_chunk_coords: int array of length 3 - coords in chunk space - :param time_stamp: datetime - :param verbose: bool - :param n_threads: int - """ - def _read_subchunks_thread(chunk_coord): - # Get start and end key - x, y, z = chunk_coord - - columns = [column_keys.Hierarchy.Child] + \ - [column_keys.Connectivity.CrossChunkEdge[l] - for l in range(layer_id - 1, self.n_layers)] - range_read = self.range_read_chunk(layer_id - 1, x, y, z, - columns=columns) - - # Due to restarted jobs some nodes in the layer below might be - # duplicated. We want to ignore the earlier created node(s) because - # they belong to the failed job. We can find these duplicates only - # by comparing their children because each node has a unique id. - # However, we can use that more recently created nodes have higher - # segment ids (not true on root layer but we do not have that here. - # We are only interested in the latest version of any duplicated - # parents. 
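The duplicate-filtering idea described in the comment above is easier to follow in isolation than interleaved with the Bigtable plumbing that follows, so here is a self-contained sketch with made-up values. It keeps, for each group of rows sharing the same maximum child ID, only the row with the highest segment ID, i.e. the most recently written copy; the code below performs the same filtering on the real range-read result.

```python
import collections
import numpy as np

# Hypothetical rows: the first two share max_child_id 40, i.e. they are
# duplicates of the same node written by an original and a restarted job.
segment_ids   = np.array([10, 11, 12], dtype=np.uint64)
max_child_ids = np.array([40, 40, 41], dtype=np.uint64)
row_ids       = np.array([b"row-a", b"row-b", b"row-c"])

# Sort by segment ID, highest first, so the first occurrence of every
# max_child_id is the most recently created duplicate.
order = np.argsort(segment_ids)[::-1]
row_ids = row_ids[order]
max_child_ids = max_child_ids[order]

counter = collections.defaultdict(int)
occurrence = np.zeros(len(max_child_ids), dtype=int)
for i, child_id in enumerate(max_child_ids):
    occurrence[i] = counter[child_id]
    counter[child_id] += 1

# Keep only the first (newest) occurrence of each max_child_id.
print(row_ids[occurrence == 0])  # [b'row-c' b'row-b']; b'row-a' is dropped
```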
- - # Deserialize row keys and store child with highest id for - # comparison - row_cell_dict = {} - segment_ids = [] - row_ids = [] - max_child_ids = [] - for row_id, row_data in range_read.items(): - segment_id = self.get_segment_id(row_id) - - cross_edge_columns = {k: v for (k, v) in row_data.items() - if k.family_id == self.cross_edge_family_id} - if cross_edge_columns: - row_cell_dict[row_id] = cross_edge_columns - - node_child_ids = row_data[column_keys.Hierarchy.Child][0].value - - max_child_ids.append(np.max(node_child_ids)) - segment_ids.append(segment_id) - row_ids.append(row_id) - - segment_ids = np.array(segment_ids, dtype=np.uint64) - row_ids = np.array(row_ids) - max_child_ids = np.array(max_child_ids, dtype=np.uint64) - - sorting = np.argsort(segment_ids)[::-1] - row_ids = row_ids[sorting] - max_child_ids = max_child_ids[sorting] - - counter = collections.defaultdict(int) - max_child_ids_occ_so_far = np.zeros(len(max_child_ids), - dtype=np.int) - for i_row in range(len(max_child_ids)): - max_child_ids_occ_so_far[i_row] = counter[max_child_ids[i_row]] - counter[max_child_ids[i_row]] += 1 - - # Filter last occurences (we inverted the list) of each node - m = max_child_ids_occ_so_far == 0 - row_ids = row_ids[m] - ll_node_ids.extend(row_ids) - - # Loop through nodes from this chunk - for row_id in row_ids: - if row_id in row_cell_dict: - cross_edge_dict[row_id] = {} - - cell_family = row_cell_dict[row_id] - - for l in range(layer_id - 1, self.n_layers): - row_key = column_keys.Connectivity.CrossChunkEdge[l] - if row_key in cell_family: - cross_edge_dict[row_id][l] = cell_family[row_key][0].value - - if int(layer_id - 1) in cross_edge_dict[row_id]: - atomic_cross_edges = cross_edge_dict[row_id][layer_id - 1] - - if len(atomic_cross_edges) > 0: - atomic_partner_id_dict[row_id] = \ - atomic_cross_edges[:, 1] - - new_pairs = zip(atomic_cross_edges[:, 0], - [row_id] * len(atomic_cross_edges)) - atomic_child_id_dict_pairs.extend(new_pairs) - - def _resolve_cross_chunk_edges_thread(args) -> None: - start, end = args - - for i_child_key, child_key in\ - enumerate(atomic_partner_id_dict_keys[start: end]): - this_atomic_partner_ids = atomic_partner_id_dict[child_key] - - partners = {atomic_child_id_dict[atomic_cross_id] - for atomic_cross_id in this_atomic_partner_ids - if atomic_child_id_dict[atomic_cross_id] != 0} - - if len(partners) > 0: - partners = np.array(list(partners), dtype=np.uint64)[:, None] - - this_ids = np.array([child_key] * len(partners), - dtype=np.uint64)[:, None] - these_edges = np.concatenate([this_ids, partners], axis=1) - - edge_ids.extend(these_edges) - - def _write_out_connected_components(args) -> None: - start, end = args - - # Collect cc info - parent_layer_ids = range(layer_id, self.n_layers + 1) - cc_connections = {l: [] for l in parent_layer_ids} - for i_cc, cc in enumerate(ccs[start: end]): - node_ids = unique_graph_ids[cc] - - parent_cross_edges = collections.defaultdict(list) - - # Collect row info for nodes that are in this chunk - for node_id in node_ids: - if node_id in cross_edge_dict: - # Extract edges relevant to this node - for l in range(layer_id, self.n_layers): - if l in cross_edge_dict[node_id] and \ - len(cross_edge_dict[node_id][l]) > 0: - parent_cross_edges[l].append(cross_edge_dict[node_id][l]) - - if self.use_skip_connections and len(node_ids) == 1: - for l in parent_layer_ids: - if l == self.n_layers or len(parent_cross_edges[l]) > 0: - cc_connections[l].append([node_ids, - parent_cross_edges]) - break - else: - 
cc_connections[layer_id].append([node_ids, - parent_cross_edges]) - - # Write out cc info - rows = [] - - # Iterate through layers - for parent_layer_id in parent_layer_ids: - if len(cc_connections[parent_layer_id]) == 0: - continue - - parent_chunk_id = parent_chunk_id_dict[parent_layer_id] - reserved_parent_ids = self.get_unique_node_id_range( - parent_chunk_id, step=len(cc_connections[parent_layer_id])) - - for i_cc, cc_info in enumerate(cc_connections[parent_layer_id]): - node_ids, parent_cross_edges = cc_info - - parent_id = reserved_parent_ids[i_cc] - val_dict = {column_keys.Hierarchy.Parent: parent_id} - - for node_id in node_ids: - rows.append(self.mutate_row( - serializers.serialize_uint64(node_id), - val_dict, time_stamp=time_stamp)) - - val_dict = {column_keys.Hierarchy.Child: node_ids} - for l in range(parent_layer_id, self.n_layers): - if l in parent_cross_edges and len(parent_cross_edges[l]) > 0: - val_dict[column_keys.Connectivity.CrossChunkEdge[l]] = \ - np.concatenate(parent_cross_edges[l]) - - rows.append( - self.mutate_row(serializers.serialize_uint64(parent_id), - val_dict, time_stamp=time_stamp)) - - if len(rows) > 100000: - self.bulk_write(rows) - rows = [] - - if len(rows) > 0: - self.bulk_write(rows) - - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - # Comply to resolution of BigTables TimeRange - time_stamp = get_google_compatible_time_stamp(time_stamp, - round_up=False) - - # 1 -------------------------------------------------------------------- - # The first part is concerned with reading data from the child nodes - # of this layer and pre-processing it for the second part - - time_start = time.time() - - atomic_partner_id_dict = {} - cross_edge_dict = {} - atomic_child_id_dict_pairs = [] - ll_node_ids = [] - - multi_args = child_chunk_coords - n_jobs = np.min([n_threads, len(multi_args)]) - - if n_jobs > 0: - mu.multithread_func(_read_subchunks_thread, multi_args, - n_threads=n_jobs) - - d = dict(atomic_child_id_dict_pairs) - atomic_child_id_dict = collections.defaultdict(np.uint64, d) - ll_node_ids = np.array(ll_node_ids, dtype=np.uint64) - - if verbose: - self.logger.debug("Time iterating through subchunks: %.3fs" % - (time.time() - time_start)) - time_start = time.time() - - # Extract edges from remaining cross chunk edges - # and maintain unused cross chunk edges - edge_ids = [] - # u_atomic_child_ids = np.unique(atomic_child_ids) - atomic_partner_id_dict_keys = \ - np.array(list(atomic_partner_id_dict.keys()), dtype=np.uint64) - - if n_threads > 1: - n_jobs = n_threads * 3 # Heuristic - else: - n_jobs = 1 - - n_jobs = np.min([n_jobs, len(atomic_partner_id_dict_keys)]) - - if n_jobs > 0: - spacing = np.linspace(0, len(atomic_partner_id_dict_keys), - n_jobs+1).astype(np.int) - starts = spacing[:-1] - ends = spacing[1:] - - multi_args = list(zip(starts, ends)) - - mu.multithread_func(_resolve_cross_chunk_edges_thread, multi_args, - n_threads=n_threads) - - if verbose: - self.logger.debug("Time resolving cross chunk edges: %.3fs" % - (time.time() - time_start)) - time_start = time.time() - - # 2 -------------------------------------------------------------------- - # The second part finds connected components, writes the parents to - # BigTable and updates the childs - - # Make parent id creation easier - x, y, z = np.min(child_chunk_coords, axis=0) // self.fan_out - chunk_id = self.get_chunk_id(layer=layer_id, x=x, y=y, z=z) - - parent_chunk_id_dict = 
self.get_parent_chunk_id_dict(chunk_id) - - # Extract connected components - isolated_node_mask = ~np.in1d(ll_node_ids, np.unique(edge_ids)) - add_node_ids = ll_node_ids[isolated_node_mask].squeeze() - add_edge_ids = np.vstack([add_node_ids, add_node_ids]).T - edge_ids.extend(add_edge_ids) - - graph, _, _, unique_graph_ids = flatgraph_utils.build_gt_graph( - edge_ids, make_directed=True) - - ccs = flatgraph_utils.connected_components(graph) - - if verbose: - self.logger.debug("Time connected components: %.3fs" % - (time.time() - time_start)) - time_start = time.time() - - # Add rows for nodes that are in this chunk - # a connected component at a time - if n_threads > 1: - n_jobs = n_threads * 3 # Heuristic - else: - n_jobs = 1 - - n_jobs = np.min([n_jobs, len(ccs)]) - - spacing = np.linspace(0, len(ccs), n_jobs+1).astype(np.int) - starts = spacing[:-1] - ends = spacing[1:] - - multi_args = list(zip(starts, ends)) - - mu.multithread_func(_write_out_connected_components, multi_args, - n_threads=n_threads) - - if verbose: - self.logger.debug("Time writing %d connected components in layer %d: %.3fs" % - (len(ccs), layer_id, time.time() - time_start)) - - def get_atomic_cross_edge_dict(self, node_id: np.uint64, - layer_ids: Sequence[int] = None): - """ Extracts all atomic cross edges and serves them as a dictionary - - :param node_id: np.uint64 - :param layer_ids: list of ints - :return: dict - """ - if isinstance(layer_ids, int): - layer_ids = [layer_ids] - - if layer_ids is None: - layer_ids = list(range(2, self.n_layers)) - - if not layer_ids: - return {} - - columns = [column_keys.Connectivity.CrossChunkEdge[l] for l in layer_ids] - - row = self.read_node_id_row(node_id, columns=columns) - - if not row: - return {} - - atomic_cross_edges = {} - - for l in layer_ids: - column = column_keys.Connectivity.CrossChunkEdge[l] - - atomic_cross_edges[l] = [] - - if column in row: - atomic_cross_edges[l] = row[column][0].value - - return atomic_cross_edges - - def get_parents(self, node_ids: Sequence[np.uint64], - get_only_relevant_parents: bool = True, - time_stamp: Optional[datetime.datetime] = None): - """ Acquires parents of a node at a specific time stamp - - :param node_ids: list of uint64 - :param get_only_relevant_parents: bool - True: return single parent according to time_stamp - False: return n x 2 list of all parents - ((parent_id, time_stamp), ...) - :param time_stamp: datetime or None - :return: uint64 or None - """ - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - parent_rows = self.read_node_id_rows(node_ids=node_ids, - columns=column_keys.Hierarchy.Parent, - end_time=time_stamp, - end_time_inclusive=True) - - if not parent_rows: - return None - - if get_only_relevant_parents: - return np.array([parent_rows[node_id][0].value - for node_id in node_ids]) - - parents = [] - for node_id in node_ids: - parents.append([(p.value, p.timestamp) - for p in parent_rows[node_id]]) - - return parents - - def get_parent(self, node_id: np.uint64, - get_only_relevant_parent: bool = True, - time_stamp: Optional[datetime.datetime] = None) -> Union[ - List[Tuple[np.uint64, datetime.datetime]], np.uint64]: - """ Acquires parent of a node at a specific time stamp - - :param node_id: uint64 - :param get_only_relevant_parent: bool - True: return single parent according to time_stamp - False: return n x 2 list of all parents - ((parent_id, time_stamp), ...) 
- :param time_stamp: datetime or None - :return: uint64 or None - """ - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - parents = self.read_node_id_row(node_id, - columns=column_keys.Hierarchy.Parent, - end_time=time_stamp, - end_time_inclusive=True) - - if not parents: - return None - - if get_only_relevant_parent: - return parents[0].value - - return [(p.value, p.timestamp) for p in parents] - - def get_children(self, node_id: Union[Iterable[np.uint64], np.uint64], - flatten: bool = False) -> Union[Dict[np.uint64, np.ndarray], np.ndarray]: - """Returns children for the specified NodeID or NodeIDs - - :param node_id: The NodeID or NodeIDs for which to retrieve children - :type node_id: Union[Iterable[np.uint64], np.uint64] - :param flatten: If True, combine all children into a single array, else generate a map - of input ``node_id`` to their respective children. - :type flatten: bool, default is True - :return: Children for each requested NodeID. The return type depends on the ``flatten`` - parameter. - :rtype: Union[Dict[np.uint64, np.ndarray], np.ndarray] - """ - if np.isscalar(node_id): - children = self.read_node_id_row(node_id=node_id, columns=column_keys.Hierarchy.Child) - if not children: - return np.empty(0, dtype=basetypes.NODE_ID) - return children[0].value - else: - children = self.read_node_id_rows(node_ids=node_id, columns=column_keys.Hierarchy.Child) - if flatten: - if not children: - return np.empty(0, dtype=basetypes.NODE_ID) - return np.concatenate([x[0].value for x in children.values()]) - return {x: children[x][0].value - if x in children else np.empty(0, dtype=basetypes.NODE_ID) - for x in node_id} - - def get_latest_roots(self, time_stamp: Optional[datetime.datetime] = get_max_time(), - n_threads: int = 1) -> Sequence[np.uint64]: - """ Reads _all_ root ids - - :param time_stamp: datetime.datetime - :param n_threads: int - :return: array of np.uint64 - """ - - return chunkedgraph_comp.get_latest_roots(self, time_stamp=time_stamp, - n_threads=n_threads) - - def get_delta_roots(self, - time_stamp_start: datetime.datetime, - time_stamp_end: Optional[datetime.datetime] = None, - min_seg_id: int =1, - n_threads: int = 1) -> Sequence[np.uint64]: - """ Returns root ids that have expired or have been created between two timestamps - - :param time_stamp_start: datetime.datetime - starting timestamp to return deltas from - :param time_stamp_end: datetime.datetime - ending timestamp to return deltasfrom - :param min_seg_id: int (default=1) - only search from this seg_id and higher (note not a node_id.. use get_seg_id) - :param n_threads: int (default=1) - number of threads to use in performing search - :return new_ids, expired_ids: np.arrays of np.uint64 - new_ids is an array of root_ids for roots that were created after time_stamp_start - and are still current as of time_stamp_end. - expired_ids is list of node_id's for roots the expired after time_stamp_start - but before time_stamp_end. 
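A brief usage sketch of the two root-listing helpers above; the timestamps and the `cg` handle are placeholders, and the tuple return of `get_delta_roots` follows the docstring shown here.

```python
import datetime

# `cg` is assumed to be a legacy ChunkedGraph instance.
now = datetime.datetime.utcnow()
yesterday = now - datetime.timedelta(days=1)

# Every root ID that is current as of `now` (can be slow on large graphs).
current_roots = cg.get_latest_roots(time_stamp=now, n_threads=4)

# Roots created and roots expired between the two timestamps.
new_ids, expired_ids = cg.get_delta_roots(time_stamp_start=yesterday,
                                          time_stamp_end=now)
print(len(current_roots), len(new_ids), len(expired_ids))
```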
- """ - - return chunkedgraph_comp.get_delta_roots(self, time_stamp_start=time_stamp_start, - time_stamp_end=time_stamp_end, - min_seg_id=min_seg_id, - n_threads=n_threads) - - def get_roots(self, node_ids: Sequence[np.uint64], - time_stamp: Optional[datetime.datetime] = None, - stop_layer: int = None, n_tries: int = 1): - """ Takes node ids and returns the associated agglomeration ids - - :param node_ids: list of uint64 - :param time_stamp: None or datetime - :return: np.uint64 - """ - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - # Comply to resolution of BigTables TimeRange - time_stamp = get_google_compatible_time_stamp(time_stamp, - round_up=False) - - parent_ids = np.array(node_ids) - - if stop_layer is not None: - stop_layer = min(self.n_layers, stop_layer) - else: - stop_layer = self.n_layers - - node_mask = np.ones(len(node_ids), dtype=np.bool) - node_mask[self.get_chunk_layers(node_ids) >= stop_layer] = False - for i_try in range(n_tries): - parent_ids = np.array(node_ids) - - for i_layer in range(int(stop_layer + 1)): - temp_parent_ids = self.get_parents(parent_ids[node_mask], - time_stamp=time_stamp) - - if temp_parent_ids is None: - break - else: - parent_ids[node_mask] = temp_parent_ids - - node_mask[self.get_chunk_layers(parent_ids) >= stop_layer] = False - if np.all(~node_mask): - break - - if np.all(self.get_chunk_layers(parent_ids) >= stop_layer): - break - else: - time.sleep(.5) - - return parent_ids - - def get_root(self, node_id: np.uint64, - time_stamp: Optional[datetime.datetime] = None, - get_all_parents=False, stop_layer: int = None, - n_tries: int = 1) -> Union[List[np.uint64], np.uint64]: - """ Takes a node id and returns the associated agglomeration ids - - :param node_id: uint64 - :param time_stamp: None or datetime - :return: np.uint64 - """ - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - # Comply to resolution of BigTables TimeRange - time_stamp = get_google_compatible_time_stamp(time_stamp, - round_up=False) - - parent_id = node_id - all_parent_ids = [] - - if stop_layer is not None: - stop_layer = min(self.n_layers, stop_layer) - else: - stop_layer = self.n_layers - - for i_try in range(n_tries): - parent_id = node_id - - for i_layer in range(self.get_chunk_layer(node_id), - int(stop_layer + 1)): - - temp_parent_id = self.get_parent(parent_id, - time_stamp=time_stamp) - - if temp_parent_id is None: - break - else: - parent_id = temp_parent_id - all_parent_ids.append(parent_id) - - if self.get_chunk_layer(parent_id) >= stop_layer: - break - - if self.get_chunk_layer(parent_id) >= stop_layer: - break - else: - time.sleep(.5) - - if self.get_chunk_layer(parent_id) < stop_layer: - raise Exception("Cannot find root id {}, {}".format(node_id, - time_stamp)) - - if get_all_parents: - return np.array(all_parent_ids) - else: - return parent_id - - def get_all_parents_dict(self, node_id: np.uint64, - time_stamp: Optional[datetime.datetime] = None - ) -> dict: - """ Takes a node id and returns all parents and parents' parents up to - the top - - :param node_id: uint64 - :param time_stamp: None or datetime - :return: dict - """ - parent_ids = self.get_root(node_id=node_id, time_stamp=time_stamp, - get_all_parents=True) - parent_id_layers = self.get_chunk_layers(parent_ids) - return dict(zip(parent_id_layers, parent_ids)) - - def lock_root_loop(self, root_ids: 
Sequence[np.uint64], - operation_id: np.uint64, max_tries: int = 1, - waittime_s: float = 0.5) -> Tuple[bool, np.ndarray]: - """ Attempts to lock multiple roots at the same time - - :param root_ids: list of uint64 - :param operation_id: uint64 - :param max_tries: int - :param waittime_s: float - :return: bool, list of uint64s - success, latest root ids - """ - - i_try = 0 - while i_try < max_tries: - lock_acquired = False - - # Collect latest root ids - new_root_ids: List[np.uint64] = [] - for i_root_id in range(len(root_ids)): - future_root_ids = self.get_future_root_ids(root_ids[i_root_id]) - - if len(future_root_ids) == 0: - new_root_ids.append(root_ids[i_root_id]) - else: - new_root_ids.extend(future_root_ids) - - # Attempt to lock all latest root ids - root_ids = np.unique(new_root_ids) - - for i_root_id in range(len(root_ids)): - - self.logger.debug("operation id: %d - root id: %d" % - (operation_id, root_ids[i_root_id])) - lock_acquired = self.lock_single_root(root_ids[i_root_id], - operation_id) - - # Roll back locks if one root cannot be locked - if not lock_acquired: - for j_root_id in range(len(root_ids)): - self.unlock_root(root_ids[j_root_id], operation_id) - break - - if lock_acquired: - return True, root_ids - - time.sleep(waittime_s) - i_try += 1 - self.logger.debug(f"Try {i_try}") - - return False, root_ids - - def lock_single_root(self, root_id: np.uint64, operation_id: np.uint64 - ) -> bool: - """ Attempts to lock the latest version of a root node - - :param root_id: uint64 - :param operation_id: uint64 - an id that is unique to the process asking to lock the root node - :return: bool - success - """ - - operation_id_b = serializers.serialize_uint64(operation_id) - - lock_column = column_keys.Concurrency.Lock - new_parents_column = column_keys.Hierarchy.NewParent - - # Build a column filter which tests if a lock was set (== lock column - # exists) and if it is still valid (timestamp younger than - # LOCK_EXPIRED_TIME_DELTA) and if there is no new parent (== new_parents - # exists) - - time_cutoff = datetime.datetime.utcnow() - LOCK_EXPIRED_TIME_DELTA - - # Comply to resolution of BigTables TimeRange - time_cutoff -= datetime.timedelta( - microseconds=time_cutoff.microsecond % 1000) - - time_filter = TimestampRangeFilter(TimestampRange(start=time_cutoff)) - - # lock_key_filter = ColumnQualifierRegexFilter(lock_column.key) - # new_parents_key_filter = ColumnQualifierRegexFilter(new_parents_column.key) - - lock_key_filter = ColumnRangeFilter( - column_family_id=lock_column.family_id, - start_column=lock_column.key, - end_column=lock_column.key, - inclusive_start=True, - inclusive_end=True) - - new_parents_key_filter = ColumnRangeFilter( - column_family_id=new_parents_column.family_id, - start_column=new_parents_column.key, - end_column=new_parents_column.key, - inclusive_start=True, - inclusive_end=True) - - # Combine filters together - chained_filter = RowFilterChain([time_filter, lock_key_filter]) - combined_filter = ConditionalRowFilter( - base_filter=chained_filter, - true_filter=PassAllFilter(True), - false_filter=new_parents_key_filter) - - # Get conditional row using the chained filter - root_row = self.table.row(serializers.serialize_uint64(root_id), - filter_=combined_filter) - - # Set row lock if condition returns no results (state == False) - time_stamp = datetime.datetime.utcnow() - - # Comply to resolution of BigTables TimeRange - time_stamp = get_google_compatible_time_stamp(time_stamp, - round_up=False) - - root_row.set_cell(lock_column.family_id, 
lock_column.key, operation_id_b, state=False, - timestamp=time_stamp) - - # The lock was acquired when set_cell returns False (state) - lock_acquired = not root_row.commit() - - if not lock_acquired: - row = self.read_node_id_row(root_id, columns=lock_column) - - l_operation_ids = [cell.value for cell in row] - self.logger.debug(f"Locked operation ids: {l_operation_ids}") - - return lock_acquired - - def unlock_root(self, root_id: np.uint64, operation_id: np.uint64) -> bool: - """ Unlocks a root - - This is mainly used for cases where multiple roots need to be locked and - locking was not sucessful for all of them - - :param root_id: np.uint64 - :param operation_id: uint64 - an id that is unique to the process asking to lock the root node - :return: bool - success - """ - lock_column = column_keys.Concurrency.Lock - operation_id_b = lock_column.serialize(operation_id) - - # Build a column filter which tests if a lock was set (== lock column - # exists) and if it is still valid (timestamp younger than - # LOCK_EXPIRED_TIME_DELTA) and if the given operation_id is still - # the active lock holder - - time_cutoff = datetime.datetime.utcnow() - LOCK_EXPIRED_TIME_DELTA - - # Comply to resolution of BigTables TimeRange - time_cutoff -= datetime.timedelta( - microseconds=time_cutoff.microsecond % 1000) - - time_filter = TimestampRangeFilter(TimestampRange(start=time_cutoff)) - - # column_key_filter = ColumnQualifierRegexFilter(lock_column.key) - # value_filter = ColumnQualifierRegexFilter(operation_id_b) - - column_key_filter = ColumnRangeFilter( - column_family_id=lock_column.family_id, - start_column=lock_column.key, - end_column=lock_column.key, - inclusive_start=True, - inclusive_end=True) - - value_filter = ValueRangeFilter( - start_value=operation_id_b, - end_value=operation_id_b, - inclusive_start=True, - inclusive_end=True) - - # Chain these filters together - chained_filter = RowFilterChain([time_filter, column_key_filter, - value_filter]) - - # Get conditional row using the chained filter - root_row = self.table.row(serializers.serialize_uint64(root_id), - filter_=chained_filter) - - # Delete row if conditions are met (state == True) - root_row.delete_cell(lock_column.family_id, lock_column.key, state=True) - - return root_row.commit() - - def check_and_renew_root_locks(self, root_ids: Iterable[np.uint64], - operation_id: np.uint64) -> bool: - """ Tests if the roots are locked with the provided operation_id and - renews the lock to reset the time_stam - - This is mainly used before executing a bulk write - - :param root_ids: uint64 - :param operation_id: uint64 - an id that is unique to the process asking to lock the root node - :return: bool - success - """ - - for root_id in root_ids: - if not self.check_and_renew_root_lock_single(root_id, operation_id): - self.logger.warning(f"check_and_renew_root_locks failed - {root_id}") - return False - - return True - - def check_and_renew_root_lock_single(self, root_id: np.uint64, - operation_id: np.uint64) -> bool: - """ Tests if the root is locked with the provided operation_id and - renews the lock to reset the time_stam - - This is mainly used before executing a bulk write - - :param root_id: uint64 - :param operation_id: uint64 - an id that is unique to the process asking to lock the root node - :return: bool - success - """ - lock_column = column_keys.Concurrency.Lock - new_parents_column = column_keys.Hierarchy.NewParent - - operation_id_b = lock_column.serialize(operation_id) - - # Build a column filter which tests if a lock was set (== 
lock column - # exists) and if the given operation_id is still the active lock holder - # and there is no new parent (== new_parents column exists). The latter - # is not necessary but we include it as a backup to prevent things - # from going really bad. - - # column_key_filter = ColumnQualifierRegexFilter(lock_column.key) - # value_filter = ColumnQualifierRegexFilter(operation_id_b) - - column_key_filter = ColumnRangeFilter( - column_family_id=lock_column.family_id, - start_column=lock_column.key, - end_column=lock_column.key, - inclusive_start=True, - inclusive_end=True) - - value_filter = ValueRangeFilter( - start_value=operation_id_b, - end_value=operation_id_b, - inclusive_start=True, - inclusive_end=True) - - new_parents_key_filter = ColumnRangeFilter( - column_family_id=self.family_id, - start_column=new_parents_column.key, - end_column=new_parents_column.key, - inclusive_start=True, - inclusive_end=True) - - # Chain these filters together - chained_filter = RowFilterChain([column_key_filter, value_filter]) - combined_filter = ConditionalRowFilter( - base_filter=chained_filter, - true_filter=new_parents_key_filter, - false_filter=PassAllFilter(True)) - - # Get conditional row using the chained filter - root_row = self.table.row(serializers.serialize_uint64(root_id), - filter_=combined_filter) - - # Set row lock if condition returns a result (state == True) - root_row.set_cell(lock_column.family_id, lock_column.key, operation_id_b, state=False) - - # The lock was acquired when set_cell returns True (state) - lock_acquired = not root_row.commit() - - return lock_acquired - - def read_consolidated_lock_timestamp(self, root_ids: Sequence[np.uint64], - operation_ids: Sequence[np.uint64] - ) -> Union[datetime.datetime, None]: - """ Returns minimum of many lock timestamps - - :param root_ids: np.ndarray - :param operation_ids: np.ndarray - :return: - """ - time_stamps = [] - for root_id, operation_id in zip(root_ids, operation_ids): - time_stamp = self.read_lock_timestamp(root_id, operation_id) - - if time_stamp is None: - return None - - time_stamps.append(time_stamp) - - if len(time_stamps) == 0: - return None - - return np.min(time_stamps) - - def read_lock_timestamp(self, root_id: np.uint64, operation_id: np.uint64 - ) -> Union[datetime.datetime, None]: - """ Reads timestamp from lock row to get a consistent timestamp across - multiple nodes / pods - - :param root_id: np.uint64 - :param operation_id: np.uint64 - Checks whether the root_id is actually locked with this operation_id - :return: datetime.datetime or None - """ - row = self.read_node_id_row(root_id, - columns=column_keys.Concurrency.Lock) - - if len(row) == 0: - self.logger.warning(f"No lock found for {root_id}") - return None - - if row[0].value != operation_id: - self.logger.warning(f"{root_id} not locked with {operation_id}") - return None - - return row[0].timestamp - - def get_latest_root_id(self, root_id: np.uint64) -> np.ndarray: - """ Returns the latest root id associated with the provided root id - - :param root_id: uint64 - :return: list of uint64s - """ - - id_working_set = [root_id] - column = column_keys.Hierarchy.NewParent - latest_root_ids = [] - - while len(id_working_set) > 0: - next_id = id_working_set[0] - del(id_working_set[0]) - row = self.read_node_id_row(next_id, columns=column) - - # Check if a new root id was attached to this root id - if row: - id_working_set.extend(row[0].value) - else: - latest_root_ids.append(next_id) - - return np.unique(latest_root_ids) - - def get_future_root_ids(self, 
root_id: np.uint64, - time_stamp: Optional[datetime.datetime] = - get_max_time())-> np.ndarray: - """ Returns all future root ids emerging from this root - - This search happens in a monotic fashion. At no point are past root - ids of future root ids taken into account. - - :param root_id: np.uint64 - :param time_stamp: None or datetime - restrict search to ids created before this time_stamp - None=search whole future - :return: array of uint64 - """ - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - # Comply to resolution of BigTables TimeRange - time_stamp = get_google_compatible_time_stamp(time_stamp, - round_up=False) - - id_history = [] - - next_ids = [root_id] - while len(next_ids): - temp_next_ids = [] - - for next_id in next_ids: - row = self.read_node_id_row(next_id, columns=[column_keys.Hierarchy.NewParent, - column_keys.Hierarchy.Child]) - if column_keys.Hierarchy.NewParent in row: - ids = row[column_keys.Hierarchy.NewParent][0].value - row_time_stamp = row[column_keys.Hierarchy.NewParent][0].timestamp - elif column_keys.Hierarchy.Child in row: - ids = None - row_time_stamp = row[column_keys.Hierarchy.Child][0].timestamp - else: - raise cg_exceptions.ChunkedGraphError("Error retrieving future root ID of %s" % next_id) - - if row_time_stamp < time_stamp: - if ids is not None: - temp_next_ids.extend(ids) - - if next_id != root_id: - id_history.append(next_id) - - next_ids = temp_next_ids - - return np.unique(np.array(id_history, dtype=np.uint64)) - - def get_past_root_ids(self, root_id: np.uint64, - time_stamp: Optional[datetime.datetime] = - get_min_time()) -> np.ndarray: - """ Returns all future root ids emerging from this root - - This search happens in a monotic fashion. At no point are future root - ids of past root ids taken into account. - - :param root_id: np.uint64 - :param time_stamp: None or datetime - restrict search to ids created after this time_stamp - None=search whole future - :return: array of uint64 - """ - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - # Comply to resolution of BigTables TimeRange - time_stamp = get_google_compatible_time_stamp(time_stamp, - round_up=False) - - id_history = [] - - next_ids = [root_id] - while len(next_ids): - temp_next_ids = [] - - for next_id in next_ids: - row = self.read_node_id_row(next_id, columns=[column_keys.Hierarchy.FormerParent, - column_keys.Hierarchy.Child]) - if column_keys.Hierarchy.FormerParent in row: - ids = row[column_keys.Hierarchy.FormerParent][0].value - row_time_stamp = row[column_keys.Hierarchy.FormerParent][0].timestamp - elif column_keys.Hierarchy.Child in row: - ids = None - row_time_stamp = row[column_keys.Hierarchy.Child][0].timestamp - else: - raise cg_exceptions.ChunkedGraphError("Error retrieving past root ID of %s" % next_id) - - if row_time_stamp > time_stamp: - if ids is not None: - temp_next_ids.extend(ids) - - if next_id != root_id: - id_history.append(next_id) - - next_ids = temp_next_ids - - return np.unique(np.array(id_history, dtype=np.uint64)) - - def get_root_id_history(self, root_id: np.uint64, - time_stamp_past: - Optional[datetime.datetime] = get_min_time(), - time_stamp_future: - Optional[datetime.datetime] = get_max_time() - ) -> np.ndarray: - """ Returns all future root ids emerging from this root - - This search happens in a monotic fashion. At no point are future root - ids of past root ids or past root ids of future root ids taken into - account. 
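Because the three lineage helpers (past, future, combined history) differ only in the direction they walk the NewParent / FormerParent links, a compact usage sketch may help; the root ID and `cg` handle are made up, and the concatenation mirrors what `get_root_id_history` itself does below.

```python
import numpy as np

# `cg` is assumed to be a legacy ChunkedGraph instance; the root ID is made up.
root_id = np.uint64(432345564227567616)

past_ids   = cg.get_past_root_ids(root_id)    # roots this one was derived from
future_ids = cg.get_future_root_ids(root_id)  # roots that later superseded it

# get_root_id_history() assembles exactly this: past + the root itself + future.
history = np.concatenate([past_ids,
                          np.array([root_id], dtype=np.uint64),
                          future_ids])
print(history)
```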
- - :param root_id: np.uint64 - :param time_stamp_past: None or datetime - restrict search to ids created after this time_stamp - None=search whole future - :param time_stamp_future: None or datetime - restrict search to ids created before this time_stamp - None=search whole future - :return: array of uint64 - """ - past_ids = self.get_past_root_ids(root_id=root_id, - time_stamp=time_stamp_past) - future_ids = self.get_future_root_ids(root_id=root_id, - time_stamp=time_stamp_future) - - history_ids = np.concatenate([past_ids, - np.array([root_id], dtype=np.uint64), - future_ids]) - - return history_ids - - def get_change_log(self, root_id: np.uint64, - correct_for_wrong_coord_type: bool = True, - time_stamp_past: Optional[datetime.datetime] = get_min_time() - ) -> dict: - """ Returns all past root ids for this root - - This search happens in a monotic fashion. At no point are future root - ids of past root ids taken into account. - - :param root_id: np.uint64 - :param correct_for_wrong_coord_type: bool - pinky100? --> True - :param time_stamp_past: None or datetime - restrict search to ids created after this time_stamp - None=search whole past - :return: past ids, merge sv ids, merge edge coords, split sv ids - """ - if time_stamp_past.tzinfo is None: - time_stamp_past = UTC.localize(time_stamp_past) - - id_history = [] - merge_history = [] - merge_history_edges = [] - split_history = [] - - next_ids = [root_id] - while len(next_ids): - temp_next_ids = [] - former_parent_col = column_keys.Hierarchy.FormerParent - row_dict = self.read_node_id_rows(node_ids=next_ids, - columns=[former_parent_col]) - - for row in row_dict.values(): - if column_keys.Hierarchy.FormerParent in row: - if time_stamp_past > row[former_parent_col][0].timestamp: - continue - - ids = row[former_parent_col][0].value - - lock_col = column_keys.Concurrency.Lock - former_row = self.read_node_id_row(ids[0], - columns=[lock_col]) - operation_id = former_row[lock_col][0].value - log_row = self.read_log_row(operation_id) - is_merge = column_keys.OperationLogs.AddedEdge in log_row - - for id_ in ids: - if id_ in id_history: - continue - - id_history.append(id_) - temp_next_ids.append(id_) - - if is_merge: - added_edges = log_row[column_keys.OperationLogs.AddedEdge] - merge_history.append(added_edges) - - coords = [log_row[column_keys.OperationLogs.SourceCoordinate], - log_row[column_keys.OperationLogs.SinkCoordinate]] - - if correct_for_wrong_coord_type: - # A little hack because we got the datatype wrong... 
- coords = [np.frombuffer(coords[0]), - np.frombuffer(coords[1])] - coords *= self.segmentation_resolution - - merge_history_edges.append(coords) - - if not is_merge: - removed_edges = log_row[column_keys.OperationLogs.RemovedEdge] - split_history.append(removed_edges) - else: - continue - - next_ids = temp_next_ids - - return {"past_ids": np.unique(np.array(id_history, dtype=np.uint64)), - "merge_edges": np.array(merge_history), - "merge_edge_coords": np.array(merge_history_edges), - "split_edges": np.array(split_history)} - - def normalize_bounding_box(self, - bounding_box: Optional[Sequence[Sequence[int]]], - bb_is_coordinate: bool) -> \ - Union[Sequence[Sequence[int]], None]: - if bounding_box is None: - return None - - if bb_is_coordinate: - bounding_box[0] = self.get_chunk_coordinates_from_vol_coordinates( - bounding_box[0][0], bounding_box[0][1], bounding_box[0][2], - resolution=self.cv.resolution, ceil=False) - bounding_box[1] = self.get_chunk_coordinates_from_vol_coordinates( - bounding_box[1][0], bounding_box[1][1], bounding_box[1][2], - resolution=self.cv.resolution, ceil=True) - return bounding_box - else: - return np.array(bounding_box, dtype=np.int) - - def _get_subgraph_higher_layer_nodes( - self, node_id: np.uint64, - bounding_box: Optional[Sequence[Sequence[int]]], - return_layers: Sequence[int], - verbose: bool): - - def _get_subgraph_higher_layer_nodes_threaded( - node_ids: Iterable[np.uint64]) -> List[np.uint64]: - children = self.get_children(node_ids, flatten=True) - - if len(children) > 0 and bounding_box is not None: - chunk_coordinates = np.array([self.get_chunk_coordinates(c) for c in children]) - child_layers = self.get_chunk_layers(children) - adapt_child_layers = child_layers - 2 - adapt_child_layers[adapt_child_layers < 0] = 0 - - bounding_box_layer = bounding_box[None] / \ - (self.fan_out ** adapt_child_layers)[:, None, None] - - bound_check = np.array([ - np.all(chunk_coordinates < bounding_box_layer[:, 1], axis=1), - np.all(chunk_coordinates + 1 > bounding_box_layer[:, 0], axis=1)]).T - - bound_check_mask = np.all(bound_check, axis=1) - children = children[bound_check_mask] - - return children - - if bounding_box is not None: - bounding_box = np.array(bounding_box) - - layer = self.get_chunk_layer(node_id) - assert layer > 1 - - nodes_per_layer = {} - child_ids = np.array([node_id], dtype=np.uint64) - stop_layer = max(2, np.min(return_layers)) - - if layer in return_layers: - nodes_per_layer[layer] = child_ids - - if verbose: - time_start = time.time() - - while layer > stop_layer: - # Use heuristic to guess the optimal number of threads - child_id_layers = self.get_chunk_layers(child_ids) - this_layer_m = child_id_layers == layer - this_layer_child_ids = child_ids[this_layer_m] - next_layer_child_ids = child_ids[~this_layer_m] - - n_child_ids = len(child_ids) - this_n_threads = np.min([int(n_child_ids // 50000) + 1, mu.n_cpus]) - - child_ids = np.fromiter(chain.from_iterable(mu.multithread_func( - _get_subgraph_higher_layer_nodes_threaded, - np.array_split(this_layer_child_ids, this_n_threads), - n_threads=this_n_threads, debug=this_n_threads == 1)), np.uint64) - child_ids = np.concatenate([child_ids, next_layer_child_ids]) - - if verbose: - self.logger.debug("Layer %d: %.3fms for %d children with %d threads" % - (layer, (time.time() - time_start) * 1000, n_child_ids, - this_n_threads)) - time_start = time.time() - - layer -= 1 - if layer in return_layers: - nodes_per_layer[layer] = child_ids - - return nodes_per_layer - - def get_subgraph_edges(self, 
agglomeration_id: np.uint64, - bounding_box: Optional[Sequence[Sequence[int]]] = None, - bb_is_coordinate: bool = False, - connected_edges=True, - verbose: bool = True - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ Return all atomic edges between supervoxels belonging to the - specified agglomeration ID within the defined bounding box - - :param agglomeration_id: int - :param bounding_box: [[x_l, y_l, z_l], [x_h, y_h, z_h]] - :param bb_is_coordinate: bool - :param verbose: bool - :return: edge list - """ - - def _get_subgraph_layer2_edges(node_ids) -> \ - Tuple[List[np.ndarray], List[np.float32], List[np.uint64]]: - return self.get_subgraph_chunk(node_ids, - connected_edges=connected_edges, - time_stamp=time_stamp) - - time_stamp = self.read_node_id_row(agglomeration_id, - columns=column_keys.Hierarchy.Child)[0].timestamp - - bounding_box = self.normalize_bounding_box(bounding_box, bb_is_coordinate) - - # Layer 3+ - child_ids = self._get_subgraph_higher_layer_nodes( - node_id=agglomeration_id, bounding_box=bounding_box, - return_layers=[2], verbose=verbose)[2] - - # Layer 2 - if verbose: - time_start = time.time() - - - child_chunk_ids = self.get_chunk_ids_from_node_ids(child_ids) - u_ccids = np.unique(child_chunk_ids) - - child_blocks = [] - # Make blocks of child ids that are in the same chunk - for u_ccid in u_ccids: - child_blocks.append(child_ids[child_chunk_ids == u_ccid]) - - n_child_ids = len(child_ids) - this_n_threads = np.min([int(n_child_ids // 50000) + 1, mu.n_cpus]) - - edge_infos = mu.multithread_func( - _get_subgraph_layer2_edges, - np.array_split(child_ids, this_n_threads), - n_threads=this_n_threads, debug=this_n_threads == 1) - - affinities = np.array([], dtype=np.float32) - areas = np.array([], dtype=np.uint64) - edges = np.array([], dtype=np.uint64).reshape(0, 2) - - for edge_info in edge_infos: - _edges, _affinities, _areas = edge_info - areas = np.concatenate([areas, _areas]) - affinities = np.concatenate([affinities, _affinities]) - edges = np.concatenate([edges, _edges]) - - if verbose: - self.logger.debug("Layer %d: %.3fms for %d childs with %d threads" % - (2, (time.time() - time_start) * 1000, - n_child_ids, this_n_threads)) - - return edges, affinities, areas - - def get_subgraph_nodes(self, agglomeration_id: np.uint64, - bounding_box: Optional[Sequence[Sequence[int]]] = None, - bb_is_coordinate: bool = False, - return_layers: List[int] = [1], - verbose: bool = True) -> \ - Union[Dict[int, np.ndarray], np.ndarray]: - """ Return all nodes belonging to the specified agglomeration ID within - the defined bounding box and requested layers. 
- - :param agglomeration_id: np.uint64 - :param bounding_box: [[x_l, y_l, z_l], [x_h, y_h, z_h]] - :param bb_is_coordinate: bool - :param return_layers: List[int] - :param verbose: bool - :return: np.array of atomic IDs if single layer is requested, - Dict[int, np.array] if multiple layers are requested - """ - - def _get_subgraph_layer2_nodes(node_ids: Iterable[np.uint64]) -> np.ndarray: - return self.get_children(node_ids, flatten=True) - - stop_layer = np.min(return_layers) - bounding_box = self.normalize_bounding_box(bounding_box, - bb_is_coordinate) - - # Layer 3+ - if stop_layer >= 2: - nodes_per_layer = self._get_subgraph_higher_layer_nodes( - node_id=agglomeration_id, bounding_box=bounding_box, - return_layers=return_layers, verbose=verbose) - else: - # Need to retrieve layer 2 even if the user doesn't require it - nodes_per_layer = self._get_subgraph_higher_layer_nodes( - node_id=agglomeration_id, bounding_box=bounding_box, - return_layers=return_layers+[2], verbose=verbose) - - # Layer 2 - if verbose: - time_start = time.time() - - child_ids = nodes_per_layer[2] - if 2 not in return_layers: - del nodes_per_layer[2] - - # Use heuristic to guess the optimal number of threads - n_child_ids = len(child_ids) - this_n_threads = np.min([int(n_child_ids // 50000) + 1, mu.n_cpus]) - - child_ids = np.fromiter(chain.from_iterable(mu.multithread_func( - _get_subgraph_layer2_nodes, - np.array_split(child_ids, this_n_threads), - n_threads=this_n_threads, debug=this_n_threads == 1)), dtype=np.uint64) - - if verbose: - self.logger.debug("Layer 2: %.3fms for %d children with %d threads" % - ((time.time() - time_start) * 1000, n_child_ids, - this_n_threads)) - - nodes_per_layer[1] = child_ids - - if len(nodes_per_layer) == 1: - return list(nodes_per_layer.values())[0] - else: - return nodes_per_layer - - def flatten_row_dict(self, row_dict: Dict[column_keys._Column, - List[bigtable.row_data.Cell]]) -> Dict: - """ Flattens multiple entries to columns by appending them - - :param row_dict: dict - family key has to be resolved - :return: dict - """ - - flattened_row_dict = {} - for column, column_entries in row_dict.items(): - flattened_row_dict[column] = [] - - if len(column_entries) > 0: - for column_entry in column_entries[::-1]: - flattened_row_dict[column].append(column_entry.value) - - if np.isscalar(column_entry.value): - flattened_row_dict[column] = np.array(flattened_row_dict[column]) - else: - flattened_row_dict[column] = np.concatenate(flattened_row_dict[column]) - else: - flattened_row_dict[column] = column.deserialize(b'') - - if column == column_keys.Connectivity.Connected: - u_ids, c_ids = np.unique(flattened_row_dict[column], - return_counts=True) - flattened_row_dict[column] = u_ids[(c_ids % 2) == 1].astype(column.basetype) - - return flattened_row_dict - - def get_chunk_split_partners(self, atomic_id: np.uint64): - """ Finds all atomic nodes beloning to the same supervoxel before - chunking (affs == inf) - - :param atomic_id: np.uint64 - :return: list of np.uint64 - """ - - chunk_split_partners = [atomic_id] - atomic_ids = [atomic_id] - - while len(atomic_ids) > 0: - atomic_id = atomic_ids[0] - del atomic_ids[0] - - partners, affs, _ = self.get_atomic_partners(atomic_id, - include_connected_partners=True, - include_disconnected_partners=False) - - m = np.isinf(affs) - - inf_partners = partners[m] - new_chunk_split_partners = inf_partners[~np.in1d(inf_partners, chunk_split_partners)] - atomic_ids.extend(new_chunk_split_partners) - chunk_split_partners.extend(new_chunk_split_partners) 
- - return chunk_split_partners - - def get_all_original_partners(self, atomic_id: np.uint64): - """ Finds all partners from the unchunked region graph - Merges split supervoxels over chunk boundaries first (affs == inf) - - :param atomic_id: np.uint64 - :return: dict np.uint64 -> np.uint64 - """ - - atomic_ids = [atomic_id] - partner_dict = {} - - while len(atomic_ids) > 0: - atomic_id = atomic_ids[0] - del atomic_ids[0] - - partners, affs, _ = self.get_atomic_partners(atomic_id, - include_connected_partners=True, - include_disconnected_partners=False) - - m = np.isinf(affs) - partner_dict[atomic_id] = partners[~m] - - inf_partners = partners[m] - new_chunk_split_partners = inf_partners[ - ~np.in1d(inf_partners, list(partner_dict.keys()))] - atomic_ids.extend(new_chunk_split_partners) - - return partner_dict - - def get_atomic_node_partners(self, atomic_id: np.uint64, - time_stamp: datetime.datetime = get_max_time() - ) -> Dict: - """ Reads register partner ids - - :param atomic_id: np.uint64 - :param time_stamp: datetime.datetime - :return: dict - """ - col_partner = column_keys.Connectivity.Partner - col_connected = column_keys.Connectivity.Connected - columns = [col_partner, col_connected] - row_dict = self.read_node_id_row(atomic_id, columns=columns, - end_time=time_stamp, end_time_inclusive=True) - flattened_row_dict = self.flatten_row_dict(row_dict) - return flattened_row_dict[col_partner][flattened_row_dict[col_connected]] - - def _get_atomic_node_info_core(self, row_dict) -> Dict: - """ Reads connectivity information for a single node - - :param atomic_id: np.uint64 - :param time_stamp: datetime.datetime - :return: dict - """ - flattened_row_dict = self.flatten_row_dict(row_dict) - all_ids = np.arange(len(flattened_row_dict[column_keys.Connectivity.Partner]), - dtype=column_keys.Connectivity.Partner.basetype) - disconnected_m = ~np.in1d(all_ids, - flattened_row_dict[column_keys.Connectivity.Connected]) - flattened_row_dict[column_keys.Connectivity.Disconnected] = all_ids[disconnected_m] - - return flattened_row_dict - - def get_atomic_node_info(self, atomic_id: np.uint64, - time_stamp: datetime.datetime = get_max_time() - ) -> Dict: - """ Reads connectivity information for a single node - - :param atomic_id: np.uint64 - :param time_stamp: datetime.datetime - :return: dict - """ - columns = [column_keys.Connectivity.Connected, column_keys.Connectivity.Affinity, - column_keys.Connectivity.Area, column_keys.Connectivity.Partner, - column_keys.Hierarchy.Parent] - row_dict = self.read_node_id_row(atomic_id, columns=columns, - end_time=time_stamp, end_time_inclusive=True) - - return self._get_atomic_node_info_core(row_dict) - - def _get_atomic_partners_core(self, flattened_row_dict: Dict, - include_connected_partners=True, - include_disconnected_partners=False - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ Extracts the atomic partners and affinities for a given timestamp - - :param flattened_row_dict: dict - :param include_connected_partners: bool - :param include_disconnected_partners: bool - :return: list of np.ndarrays - """ - columns = [] - if include_connected_partners: - columns.append(column_keys.Connectivity.Connected) - if include_disconnected_partners: - columns.append(column_keys.Connectivity.Disconnected) - - included_ids = [] - for column in columns: - included_ids.extend(flattened_row_dict[column]) - - included_ids = np.array(included_ids, dtype=column_keys.Connectivity.Connected.basetype) - - areas = flattened_row_dict[column_keys.Connectivity.Area][included_ids] 
- affinities = flattened_row_dict[column_keys.Connectivity.Affinity][included_ids] - partners = flattened_row_dict[column_keys.Connectivity.Partner][included_ids] - - return partners, affinities, areas - - def get_atomic_partners(self, atomic_id: np.uint64, - include_connected_partners=True, - include_disconnected_partners=False, - time_stamp: Optional[datetime.datetime] = get_max_time() - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ Extracts the atomic partners and affinities for a given timestamp - - :param atomic_id: np.uint64 - :param include_connected_partners: bool - :param include_disconnected_partners: bool - :param time_stamp: None or datetime - :return: list of np.ndarrays - """ - assert include_connected_partners or include_disconnected_partners - - flattened_row_dict = self.get_atomic_node_info(atomic_id, - time_stamp=time_stamp) - - return self._get_atomic_partners_core(flattened_row_dict, - include_connected_partners, - include_disconnected_partners) - - def _retrieve_connectivity(self, dict_item: Tuple[np.uint64, Dict[column_keys._Column, List[bigtable.row_data.Cell]]], - connected_edges: bool = True): - node_id, row = dict_item - - tmp = set() - for x in itertools.chain.from_iterable(generation.value for generation in row[column_keys.Connectivity.Connected][::-1]): - tmp.remove(x) if x in tmp else tmp.add(x) - - connected_indices = np.fromiter(tmp, np.uint64) - if column_keys.Connectivity.Partner in row: - edges = np.fromiter(itertools.chain.from_iterable( - (node_id, partner_id) - for generation in row[column_keys.Connectivity.Partner][::-1] - for partner_id in generation.value), - dtype=basetypes.NODE_ID).reshape((-1, 2)) - edges = self._connected_or_not(edges, connected_indices, - connected_edges) - else: - edges = np.empty((0, 2), basetypes.NODE_ID) - - if column_keys.Connectivity.Affinity in row: - affinities = np.fromiter(itertools.chain.from_iterable( - generation.value for generation in row[column_keys.Connectivity.Affinity][::-1]), - dtype=basetypes.EDGE_AFFINITY) - affinities = self._connected_or_not(affinities, connected_indices, - connected_edges) - else: - affinities = np.empty(0, basetypes.EDGE_AFFINITY) - - if column_keys.Connectivity.Area in row: - areas = np.fromiter(itertools.chain.from_iterable( - generation.value for generation in row[column_keys.Connectivity.Area][::-1]), - dtype=basetypes.EDGE_AREA) - areas = self._connected_or_not(areas, connected_indices, - connected_edges) - else: - areas = np.empty(0, basetypes.EDGE_AREA) - - return edges, affinities, areas - - def _connected_or_not(self, array, connected_indices, connected): - """ - Either filters the first dimension of a numpy array by the passed - indices or their complement. Used to select edge descriptors for - those that are either connected or not connected. - """ - mask = np.zeros((array.shape[0],), dtype=np.bool) - mask[connected_indices] = True - - if connected: - return array[mask, ...] - else: - return array[~mask, ...] 
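The parity trick used by flatten_row_dict and _retrieve_connectivity above is easy to miss: every edit appends a fresh generation of partner indices to the Connected column, and an index that occurs an odd number of times across all generations is connected right now. Below is a minimal, self-contained sketch of that reconstruction, with synthetic generations standing in for BigTable cells; currently_connected and split_by_connected are illustrative helpers, not part of the ChunkedGraph API.

    import numpy as np

    def currently_connected(generations):
        """Toggle membership once per occurrence; an index written an odd
        number of times across all generations is currently connected."""
        toggled = set()
        for generation in generations:
            for idx in generation:
                toggled.remove(idx) if idx in toggled else toggled.add(idx)
        return np.fromiter(toggled, dtype=np.uint64)

    def split_by_connected(per_partner_rows, connected_indices, connected=True):
        """Select rows (edges, affinities, areas) for connected or
        disconnected partners via a boolean mask, as _connected_or_not does."""
        mask = np.zeros(len(per_partner_rows), dtype=bool)
        mask[connected_indices.astype(np.int64)] = True
        return per_partner_rows[mask] if connected else per_partner_rows[~mask]

    # Partner indices written over three edits:
    # 0 and 1 connected, then 1 cut again, then 2 added.
    generations = [[0, 1], [1], [2]]
    connected = currently_connected(generations)            # indices {0, 2}
    edges = np.array([[10, 20], [10, 21], [10, 22]], dtype=np.uint64)
    print(split_by_connected(edges, connected))              # rows 0 and 2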
- - def get_subgraph_chunk(self, node_ids: Iterable[np.uint64], - make_unique: bool = True, - connected_edges: bool = True, - time_stamp: Optional[datetime.datetime] = None - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ Takes an atomic id and returns the associated agglomeration ids - - :param node_ids: array of np.uint64 - :param make_unique: bool - :param connected_edges: bool - :param time_stamp: None or datetime - :return: edge list - """ - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - child_ids = self.get_children(node_ids, flatten=True) - - row_dict = self.read_node_id_rows(node_ids=child_ids, - columns=[column_keys.Connectivity.Area, - column_keys.Connectivity.Affinity, - column_keys.Connectivity.Partner, - column_keys.Connectivity.Connected, - column_keys.Connectivity.Disconnected], - end_time=time_stamp, - end_time_inclusive=True) - - tmp_edges, tmp_affinites, tmp_areas = [], [], [] - for row_dict_item in row_dict.items(): - edges, affinities, areas = self._retrieve_connectivity(row_dict_item, - connected_edges) - tmp_edges.append(edges) - tmp_affinites.append(affinities) - tmp_areas.append(areas) - - edges = np.concatenate(tmp_edges) if tmp_edges \ - else np.empty((0, 2), dtype=basetypes.NODE_ID) - affinities = np.concatenate(tmp_affinites) if tmp_affinites \ - else np.empty(0, dtype=basetypes.EDGE_AFFINITY) - areas = np.concatenate(tmp_areas) if tmp_areas \ - else np.empty(0, dtype=basetypes.EDGE_AREA) - - # If requested, remove duplicate edges. Every edge is stored in each - # participating node. Hence, we have many edge pairs that look - # like [x, y], [y, x]. We solve this by sorting and calling np.unique - # row-wise - if make_unique and len(edges) > 0: - edges, idx = np.unique(np.sort(edges, axis=1), axis=0, - return_index=True) - affinities = affinities[idx] - areas = areas[idx] - - return edges, affinities, areas - - def add_edges( - self, - user_id: str, - atomic_edges: Sequence[np.uint64], - affinities: Sequence[np.float32] = None, - source_coord: Sequence[int] = None, - sink_coord: Sequence[int] = None, - n_tries: int = 60, - ) -> GraphEditOperation.Result: - """ Adds an edge to the chunkedgraph - - Multi-user safe through locking of the root node - - This function acquires a lock and ensures that it still owns the - lock before executing the write. - - :param user_id: str - unique id - do not just make something up, use the same id for the - same user every time - :param atomic_edges: list of two uint64s - have to be from the same two root ids! 
- :param affinities: list of np.float32 or None - will eventually be set to 1 if None - :param source_coord: list of int (n x 3) - :param sink_coord: list of int (n x 3) - :param n_tries: int - :return: GraphEditOperation.Result - """ - return MergeOperation( - self, - user_id=user_id, - added_edges=atomic_edges, - affinities=affinities, - source_coords=source_coord, - sink_coords=sink_coord, - ).execute() - - def remove_edges( - self, - user_id: str, - source_ids: Sequence[np.uint64] = None, - sink_ids: Sequence[np.uint64] = None, - source_coords: Sequence[Sequence[int]] = None, - sink_coords: Sequence[Sequence[int]] = None, - atomic_edges: Sequence[Tuple[np.uint64, np.uint64]] = None, - mincut: bool = True, - bb_offset: Tuple[int, int, int] = (240, 240, 24), - n_tries: int = 20, - ) -> GraphEditOperation.Result: - """ Removes edges - either directly or after applying a mincut - - Multi-user safe through locking of the root node - - This function acquires a lock and ensures that it still owns the - lock before executing the write. - - :param user_id: str - unique id - do not just make something up, use the same id for the - same user every time - :param source_ids: uint64 - :param sink_ids: uint64 - :param atomic_edges: list of 2 uint64 - :param source_coords: list of 3 ints - [x, y, z] coordinate of source supervoxel - :param sink_coords: list of 3 ints - [x, y, z] coordinate of sink supervoxel - :param mincut: - :param bb_offset: list of 3 ints - [x, y, z] bounding box padding beyond box spanned by coordinates - :param n_tries: int - :return: GraphEditOperation.Result - """ - if mincut: - return MulticutOperation( - self, - user_id=user_id, - source_ids=source_ids, - sink_ids=sink_ids, - source_coords=source_coords, - sink_coords=sink_coords, - bbox_offset=bb_offset, - ).execute() - - if not atomic_edges: - # Shim - can remove this check once all functions call the split properly/directly - source_ids = [source_ids] if np.isscalar(source_ids) else source_ids - sink_ids = [sink_ids] if np.isscalar(sink_ids) else sink_ids - if len(source_ids) != len(sink_ids): - raise cg_exceptions.PreconditionError( - "Split operation require the same number of source and sink IDs" - ) - atomic_edges = np.array([source_ids, sink_ids]).transpose() - - return SplitOperation( - self, - user_id=user_id, - removed_edges=atomic_edges, - source_coords=source_coords, - sink_coords=sink_coords, - ).execute() - - def undo_operation(self, user_id: str, operation_id: np.uint64) -> GraphEditOperation.Result: - """ Applies the inverse of a previous GraphEditOperation - - :param user_id: str - :param operation_id: operation_id to be inverted - :return: GraphEditOperation.Result - """ - return UndoOperation(self, user_id=user_id, operation_id=operation_id).execute() - - def redo_operation(self, user_id: str, operation_id: np.uint64) -> GraphEditOperation.Result: - """ Re-applies a previous GraphEditOperation - - :param user_id: str - :param operation_id: operation_id to be repeated - :return: GraphEditOperation.Result - """ - return RedoOperation(self, user_id=user_id, operation_id=operation_id).execute() - - def _run_multicut(self, source_ids: Sequence[np.uint64], - sink_ids: Sequence[np.uint64], - source_coords: Sequence[Sequence[int]], - sink_coords: Sequence[Sequence[int]], - bb_offset: Tuple[int, int, int] = (120, 120, 12)): - - - time_start = time.time() - - bb_offset = np.array(list(bb_offset)) - source_coords = np.array(source_coords) - sink_coords = np.array(sink_coords) - - # Decide a reasonable bounding box 
(NOT guaranteed to be successful!) - coords = np.concatenate([source_coords, sink_coords]) - bounding_box = [np.min(coords, axis=0), np.max(coords, axis=0)] - - bounding_box[0] -= bb_offset - bounding_box[1] += bb_offset - - # Verify that sink and source are from the same root object - root_ids = set() - for source_id in source_ids: - root_ids.add(self.get_root(source_id)) - for sink_id in sink_ids: - root_ids.add(self.get_root(sink_id)) - - if len(root_ids) > 1: - raise cg_exceptions.PreconditionError( - f"All supervoxel must belong to the same object. Already split?" - ) - - self.logger.debug("Get roots and check: %.3fms" % - ((time.time() - time_start) * 1000)) - time_start = time.time() # ------------------------------------------ - - root_id = root_ids.pop() - - # Get edges between local supervoxels - n_chunks_affected = np.product((np.ceil(bounding_box[1] / self.chunk_size)).astype(np.int) - - (np.floor(bounding_box[0] / self.chunk_size)).astype(np.int)) - - self.logger.debug("Number of affected chunks: %d" % n_chunks_affected) - self.logger.debug(f"Bounding box: {bounding_box}") - self.logger.debug(f"Bounding box padding: {bb_offset}") - self.logger.debug(f"Source ids: {source_ids}") - self.logger.debug(f"Sink ids: {sink_ids}") - self.logger.debug(f"Root id: {root_id}") - - edges, affs, areas = self.get_subgraph_edges(root_id, - bounding_box=bounding_box, - bb_is_coordinate=True) - self.logger.debug(f"Get edges and affs: " - f"{(time.time() - time_start) * 1000:.3f}ms") - - time_start = time.time() # ------------------------------------------ - - if len(edges) == 0: - raise cg_exceptions.PreconditionError( - f"No local edges found. " - f"Something went wrong with the bounding box?" - ) - - # Compute mincut - atomic_edges = cutting.mincut(edges, affs, source_ids, sink_ids) - - self.logger.debug(f"Mincut: {(time.time() - time_start) * 1000:.3f}ms") - - if len(atomic_edges) == 0: - raise cg_exceptions.PostconditionError( - f"Mincut failed. 
Try again...") - - # # Check if any edge in the cutset is infinite (== between chunks) - # # We would prevent such a cut - # - # atomic_edges_flattened_view = atomic_edges.view(dtype='u8,u8') - # edges_flattened_view = edges.view(dtype='u8,u8') - # - # cutset_mask = np.in1d(edges_flattened_view, atomic_edges_flattened_view) - # - # if np.any(np.isinf(affs[cutset_mask])): - # self.logger.error("inf in cutset") - # return False, None - - return atomic_edges diff --git a/pychunkedgraph/backend/chunkedgraph_comp.py b/pychunkedgraph/backend/chunkedgraph_comp.py deleted file mode 100644 index 10d99b85c..000000000 --- a/pychunkedgraph/backend/chunkedgraph_comp.py +++ /dev/null @@ -1,221 +0,0 @@ -import numpy as np -import datetime -import collections - -from pychunkedgraph.backend import chunkedgraph, flatgraph_utils -from pychunkedgraph.backend.utils import column_keys - -from multiwrapper import multiprocessing_utils as mu - -from typing import Optional, Sequence - - -def _read_delta_root_rows_thread(args) -> Sequence[list]: - start_seg_id, end_seg_id, serialized_cg_info, time_stamp_start, time_stamp_end = args - - cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) - - start_id = cg.get_node_id(segment_id=start_seg_id, - chunk_id=cg.root_chunk_id) - end_id = cg.get_node_id(segment_id=end_seg_id, - chunk_id=cg.root_chunk_id) - - # apply column filters to avoid Lock columns - rows = cg.read_node_id_rows( - start_id=start_id, - start_time=time_stamp_start, - end_id=end_id, - end_id_inclusive=False, - columns=[column_keys.Hierarchy.FormerParent, column_keys.Hierarchy.NewParent], - end_time=time_stamp_end, - end_time_inclusive=True) - - # new roots are those that have no NewParent in this time window - new_root_ids = [k for (k, v) in rows.items() - if column_keys.Hierarchy.NewParent not in v] - - # expired roots are the IDs of FormerParent's - # whose timestamp is before the start_time - expired_root_ids = [] - for k, v in rows.items(): - if column_keys.Hierarchy.FormerParent in v: - fp = v[column_keys.Hierarchy.FormerParent] - for cell_entry in fp: - expired_root_ids.extend(cell_entry.value) - - return new_root_ids, expired_root_ids - - -def _read_root_rows_thread(args) -> list: - start_seg_id, end_seg_id, serialized_cg_info, time_stamp = args - - cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) - - start_id = cg.get_node_id(segment_id=start_seg_id, - chunk_id=cg.root_chunk_id) - end_id = cg.get_node_id(segment_id=end_seg_id, - chunk_id=cg.root_chunk_id) - - rows = cg.read_node_id_rows( - start_id=start_id, - end_id=end_id, - end_id_inclusive=False, - end_time=time_stamp, - end_time_inclusive=True) - - root_ids = [k for (k, v) in rows.items() - if column_keys.Hierarchy.NewParent not in v] - - return root_ids - - -def get_latest_roots(cg, - time_stamp: Optional[datetime.datetime] = None, - n_threads: int = 1) -> Sequence[np.uint64]: - - # Create filters: time and id range - max_seg_id = cg.get_max_seg_id(cg.root_chunk_id) + 1 - - if n_threads == 1: - n_blocks = 1 - else: - n_blocks = int(np.min([n_threads * 3 + 1, max_seg_id])) - - seg_id_blocks = np.linspace(1, max_seg_id, n_blocks + 1, dtype=np.uint64) - - cg_serialized_info = cg.get_serialized_info() - - if n_threads > 1: - del cg_serialized_info["credentials"] - - multi_args = [] - for i_id_block in range(0, len(seg_id_blocks) - 1): - multi_args.append([seg_id_blocks[i_id_block], - seg_id_blocks[i_id_block + 1], - cg_serialized_info, time_stamp]) - - if n_threads == 1: - results = mu.multiprocess_func(_read_root_rows_thread, - multi_args, 
n_threads=n_threads, - verbose=False, debug=n_threads == 1) - else: - results = mu.multisubprocess_func(_read_root_rows_thread, - multi_args, n_threads=n_threads) - - root_ids = [] - for result in results: - root_ids.extend(result) - - return np.array(root_ids, dtype=np.uint64) - - -def get_delta_roots(cg, - time_stamp_start: datetime.datetime, - time_stamp_end: Optional[datetime.datetime] = None, - min_seg_id: int = 1, - n_threads: int = 1) -> Sequence[np.uint64]: - - # Create filters: time and id range - max_seg_id = cg.get_max_seg_id(cg.root_chunk_id) + 1 - - n_blocks = int(np.min([n_threads + 1, max_seg_id-min_seg_id+1])) - seg_id_blocks = np.linspace(min_seg_id, max_seg_id, n_blocks, - dtype=np.uint64) - - cg_serialized_info = cg.get_serialized_info() - - if n_threads > 1: - del cg_serialized_info["credentials"] - - multi_args = [] - for i_id_block in range(0, len(seg_id_blocks) - 1): - multi_args.append([seg_id_blocks[i_id_block], - seg_id_blocks[i_id_block + 1], - cg_serialized_info, time_stamp_start, time_stamp_end]) - - # Run parallelizing - if n_threads == 1: - results = mu.multiprocess_func(_read_delta_root_rows_thread, - multi_args, n_threads=n_threads, - verbose=False, debug=n_threads == 1) - else: - results = mu.multisubprocess_func(_read_delta_root_rows_thread, - multi_args, n_threads=n_threads) - - # aggregate all the results together - new_root_ids = [] - expired_root_id_candidates = [] - for r1, r2 in results: - new_root_ids.extend(r1) - expired_root_id_candidates.extend(r2) - expired_root_id_candidates = np.array(expired_root_id_candidates, dtype=np.uint64) - # filter for uniqueness - expired_root_id_candidates = np.unique(expired_root_id_candidates) - - # filter out the expired root id's whose creation (measured by the timestamp - # of their Child links) is after the time_stamp_start - rows = cg.read_node_id_rows(node_ids=expired_root_id_candidates, - columns=[column_keys.Hierarchy.Child], - end_time=time_stamp_start) - expired_root_ids = np.array([k for (k, v) in rows.items()], dtype=np.uint64) - - return np.array(new_root_ids, dtype=np.uint64), expired_root_ids - - -def get_contact_sites(cg, root_id, bounding_box=None, bb_is_coordinate=True, compute_partner=True): - # Get information about the root id - # All supervoxels - sv_ids = cg.get_subgraph_nodes(root_id, - bounding_box=bounding_box, - bb_is_coordinate=bb_is_coordinate) - # All edges that are _not_ connected / on - edges, affs, areas = cg.get_subgraph_edges(root_id, - bounding_box=bounding_box, - bb_is_coordinate=bb_is_coordinate, - connected_edges=False) - - # Build area lookup dictionary - cs_svs = edges[~np.in1d(edges, sv_ids).reshape(-1, 2)] - area_dict = collections.defaultdict(int) - - for area, sv_id in zip(areas, cs_svs): - area_dict[sv_id] += area - - area_dict_vec = np.vectorize(area_dict.get) - - # Extract svs from contacting root ids - u_cs_svs = np.unique(cs_svs) - - # Load edges of these cs_svs - edges_cs_svs_rows = cg.read_node_id_rows(node_ids=u_cs_svs, - columns=[column_keys.Connectivity.Partner, - column_keys.Connectivity.Connected]) - - pre_cs_edges = [] - for ri in edges_cs_svs_rows.items(): - r = cg._retrieve_connectivity(ri) - pre_cs_edges.extend(r[0]) - - graph, _, _, unique_ids = flatgraph_utils.build_gt_graph( - pre_cs_edges, make_directed=True) - - # connected components in this graph will be combined in one component - ccs = flatgraph_utils.connected_components(graph) - - cs_dict = collections.defaultdict(list) - for cc in ccs: - cc_sv_ids = unique_ids[cc] - - cc_sv_ids = 
cc_sv_ids[np.in1d(cc_sv_ids, u_cs_svs)] - cs_areas = area_dict_vec(cc_sv_ids) - - if compute_partner: - partner_root_id = int(cg.get_root(cc_sv_ids[0])) - else: - partner_root_id = len(cs_dict) - - print(partner_root_id, np.sum(cs_areas)) - - cs_dict[partner_root_id].append(np.sum(cs_areas)) - - return cs_dict \ No newline at end of file diff --git a/pychunkedgraph/backend/chunkedgraph_edits.py b/pychunkedgraph/backend/chunkedgraph_edits.py deleted file mode 100644 index 5f42d9b99..000000000 --- a/pychunkedgraph/backend/chunkedgraph_edits.py +++ /dev/null @@ -1,905 +0,0 @@ -import datetime -import numpy as np -import collections - -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union,\ - NamedTuple - -from multiwrapper import multiprocessing_utils as mu - -from pychunkedgraph.backend.chunkedgraph_utils \ - import get_google_compatible_time_stamp, combine_cross_chunk_edge_dicts -from pychunkedgraph.backend.utils import column_keys, serializers -from pychunkedgraph.backend import flatgraph_utils - -def _write_atomic_merge_edges(cg, atomic_edges, affinities, areas, time_stamp): - rows = [] - - if areas is None: - areas = np.zeros(len(atomic_edges), - dtype=column_keys.Connectivity.Area.basetype) - if affinities is None: - affinities = np.ones(len(atomic_edges)) * np.inf - affinities = affinities.astype(column_keys.Connectivity.Affinity.basetype) - - rows = [] - - u_atomic_ids = np.unique(atomic_edges) - for u_atomic_id in u_atomic_ids: - val_dict = {} - atomic_node_info = cg.get_atomic_node_info(u_atomic_id) - - edge_m0 = atomic_edges[:, 0] == u_atomic_id - edge_m1 = atomic_edges[:, 1] == u_atomic_id - - edge_partners = np.concatenate([atomic_edges[edge_m1][:, 0], - atomic_edges[edge_m0][:, 1]]) - edge_affs = np.concatenate([affinities[edge_m1], affinities[edge_m0]]) - edge_areas = np.concatenate([areas[edge_m1], areas[edge_m0]]) - - ex_partner_m = np.in1d(edge_partners, atomic_node_info[column_keys.Connectivity.Partner]) - partner_idx = np.where( - np.in1d(atomic_node_info[column_keys.Connectivity.Partner], - edge_partners[ex_partner_m]))[0] - - n_ex_partners = len(atomic_node_info[column_keys.Connectivity.Partner]) - - new_partner_idx = np.arange(n_ex_partners, - n_ex_partners + np.sum(~ex_partner_m)) - partner_idx = np.concatenate([partner_idx, new_partner_idx]) - partner_idx = np.array(partner_idx, - dtype=column_keys.Connectivity.Connected.basetype) - - val_dict[column_keys.Connectivity.Connected] = partner_idx - - if np.sum(~ex_partner_m) > 0: - edge_affs = edge_affs[~ex_partner_m] - edge_areas = edge_areas[~ex_partner_m] - - new_edge_partners = np.array(edge_partners[~ex_partner_m], - dtype=column_keys.Connectivity.Partner.basetype) - - val_dict[column_keys.Connectivity.Affinity] = edge_affs - val_dict[column_keys.Connectivity.Area] = edge_areas - val_dict[column_keys.Connectivity.Partner] = new_edge_partners - - rows.append(cg.mutate_row(serializers.serialize_uint64(u_atomic_id), - val_dict, time_stamp=time_stamp)) - - return rows - - -def _write_atomic_split_edges(cg, atomic_edges, time_stamp): - rows = [] - - u_atomic_ids = np.unique(atomic_edges) - for u_atomic_id in u_atomic_ids: - atomic_node_info = cg.get_atomic_node_info(u_atomic_id) - - partners = np.concatenate( - [atomic_edges[atomic_edges[:, 0] == u_atomic_id][:, 1], - atomic_edges[atomic_edges[:, 1] == u_atomic_id][:, 0]]) - - partner_idx = np.where( - np.in1d(atomic_node_info[column_keys.Connectivity.Partner], - partners))[0] - - partner_idx = np.array(partner_idx, - 
dtype=column_keys.Connectivity.Connected.basetype) - - val_dict = {column_keys.Connectivity.Connected: partner_idx} - rows.append(cg.mutate_row(serializers.serialize_uint64(u_atomic_id), - val_dict, time_stamp=time_stamp)) - - return rows - - -def analyze_atomic_edges(cg, atomic_edges): - lvl2_edges = [] - edge_layers = cg.get_cross_chunk_edges_layer(atomic_edges) - edge_layer_m = edge_layers > 1 - - # Edges are either within or across chunks. If an edge is across a - # chunk boundary we need to store it as cross edge. Otherwise, this - # edge will combine two formerly disconnected lvl2 segments. - cross_edge_dict = {} - for atomic_edge in atomic_edges[~edge_layer_m]: - lvl2_edges.append([cg.get_parent(atomic_edge[0]), - cg.get_parent(atomic_edge[1])]) - - for atomic_edge, layer in zip(atomic_edges[edge_layer_m], - edge_layers[edge_layer_m]): - parent_id_0 = cg.get_parent(atomic_edge[0]) - parent_id_1 = cg.get_parent(atomic_edge[1]) - - cross_edge_dict[parent_id_0] = {layer: atomic_edge} - cross_edge_dict[parent_id_1] = {layer: atomic_edge[::-1]} - - lvl2_edges.append([parent_id_0, parent_id_0]) - lvl2_edges.append([parent_id_1, parent_id_1]) - - return lvl2_edges, cross_edge_dict - - -def add_edges(cg, - operation_id: np.uint64, - atomic_edges: Sequence[Sequence[np.uint64]], - time_stamp: datetime.datetime, - areas: Optional[Sequence[np.uint64]] = None, - affinities: Optional[Sequence[np.float32]] = None - ): - """ Add edges to chunkedgraph - - Computes all new rows to be written to the chunkedgraph - - :param cg: ChunkedGraph instance - :param operation_id: np.uint64 - :param atomic_edges: list of list of np.uint64 - edges between supervoxels - :param time_stamp: datetime.datetime - :param areas: list of np.uint64 - :param affinities: list of np.float32 - :return: list - """ - def _read_cc_edges_thread(node_ids): - for node_id in node_ids: - cc_dict[node_id] = cg.read_cross_chunk_edges(node_id) - - cc_dict = {} - - atomic_edges = np.array(atomic_edges, - dtype=column_keys.Connectivity.Partner.basetype) - - # # Comply to resolution of BigTables TimeRange - # time_stamp = get_google_compatible_time_stamp(time_stamp, - # round_up=False) - - if affinities is None: - affinities = np.ones(len(atomic_edges), - dtype=column_keys.Connectivity.Affinity.basetype) - else: - affinities = np.array(affinities, - dtype=column_keys.Connectivity.Affinity.basetype) - - if areas is None: - areas = np.ones(len(atomic_edges), - dtype=column_keys.Connectivity.Area.basetype) * np.inf - else: - areas = np.array(areas, - dtype=column_keys.Connectivity.Area.basetype) - - assert len(affinities) == len(atomic_edges) - - rows = [] # list of rows to be written to BigTable - lvl2_dict = {} - lvl2_cross_chunk_edge_dict = {} - - # Analyze atomic_edges --> translate them to lvl2 edges and extract cross - # chunk edges - lvl2_edges, new_cross_edge_dict = analyze_atomic_edges(cg, atomic_edges) - - # Compute connected components on lvl2 - graph, _, _, unique_graph_ids = flatgraph_utils.build_gt_graph( - lvl2_edges, make_directed=True) - - # Read cross chunk edges efficiently - cc_dict = {} - node_ids = np.unique(lvl2_edges) - n_threads = int(np.ceil(len(node_ids) / 5)) - - node_id_blocks = np.array_split(node_ids, n_threads) - - mu.multithread_func(_read_cc_edges_thread, node_id_blocks, - n_threads=n_threads, debug=False) - - ccs = flatgraph_utils.connected_components(graph) - for cc in ccs: - lvl2_ids = unique_graph_ids[cc] - chunk_id = cg.get_chunk_id(lvl2_ids[0]) - - new_node_id = cg.get_unique_node_id(chunk_id) - 
lvl2_dict[new_node_id] = lvl2_ids - - cross_chunk_edge_dict = {} - for lvl2_id in lvl2_ids: - lvl2_id_cross_chunk_edges = cc_dict[lvl2_id] - cross_chunk_edge_dict = \ - combine_cross_chunk_edge_dicts( - cross_chunk_edge_dict, - lvl2_id_cross_chunk_edges) - - if lvl2_id in new_cross_edge_dict: - cross_chunk_edge_dict = \ - combine_cross_chunk_edge_dicts( - new_cross_edge_dict[lvl2_id], - lvl2_id_cross_chunk_edges) - - lvl2_cross_chunk_edge_dict[new_node_id] = cross_chunk_edge_dict - - if cg.n_layers == 2: - rows.extend(update_root_id_lineage(cg, [new_node_id], lvl2_ids, - operation_id=operation_id, - time_stamp=time_stamp)) - - children_ids = cg.get_children(lvl2_ids, flatten=True) - - rows.extend(create_parent_children_rows(cg, new_node_id, children_ids, - cross_chunk_edge_dict, - time_stamp)) - - # Write atomic nodes - rows.extend(_write_atomic_merge_edges(cg, atomic_edges, affinities, areas, - time_stamp=time_stamp)) - - # Propagate changes up the tree - if cg.n_layers > 2: - new_root_ids, new_rows = propagate_edits_to_root( - cg, lvl2_dict.copy(), lvl2_cross_chunk_edge_dict, - operation_id=operation_id, time_stamp=time_stamp) - rows.extend(new_rows) - else: - new_root_ids = np.array(list(lvl2_dict.keys())) - - - return new_root_ids, list(lvl2_dict.keys()), rows - - -def remove_edges(cg, operation_id: np.uint64, - atomic_edges: Sequence[Sequence[np.uint64]], - time_stamp: datetime.datetime): - - # This view of the to be removed edges helps us to compute the mask - # of the retained edges in each chunk - double_atomic_edges = np.concatenate([atomic_edges, - atomic_edges[:, ::-1]], - axis=0) - double_atomic_edges_view = double_atomic_edges.view(dtype='u8,u8') - n_edges = double_atomic_edges.shape[0] - double_atomic_edges_view = double_atomic_edges_view.reshape(n_edges) - - rows = [] # list of rows to be written to BigTable - lvl2_dict = {} - lvl2_cross_chunk_edge_dict = {} - - # Analyze atomic_edges --> translate them to lvl2 edges and extract cross - # chunk edges to be removed - lvl2_edges, old_cross_edge_dict = analyze_atomic_edges(cg, atomic_edges) - lvl2_node_ids = np.unique(lvl2_edges) - - for lvl2_node_id in lvl2_node_ids: - chunk_id = cg.get_chunk_id(lvl2_node_id) - chunk_edges, _, _ = cg.get_subgraph_chunk(lvl2_node_id, - make_unique=False) - - child_chunk_ids = cg.get_child_chunk_ids(chunk_id) - - assert len(child_chunk_ids) == 1 - child_chunk_id = child_chunk_ids[0] - - children_ids = np.unique(chunk_edges) - children_chunk_ids = cg.get_chunk_ids_from_node_ids(children_ids) - children_ids = children_ids[children_chunk_ids == child_chunk_id] - - # These edges still contain the removed edges. - # For consistency reasons we can only write to BigTable one time. - # Hence, we have to evict the to be removed "atomic_edges" from the - # queried edges. 
- retained_edges_mask = ~np.in1d( - chunk_edges.view(dtype='u8,u8').reshape(chunk_edges.shape[0]), - double_atomic_edges_view) - - chunk_edges = chunk_edges[retained_edges_mask] - - edge_layers = cg.get_cross_chunk_edges_layer(chunk_edges) - cross_edge_mask = edge_layers != 1 - - cross_edges = chunk_edges[cross_edge_mask] - cross_edge_layers = edge_layers[cross_edge_mask] - chunk_edges = chunk_edges[~cross_edge_mask] - - isolated_child_ids = children_ids[~np.in1d(children_ids, chunk_edges)] - isolated_edges = np.vstack([isolated_child_ids, isolated_child_ids]).T - - graph, _, _, unique_graph_ids = flatgraph_utils.build_gt_graph( - np.concatenate([chunk_edges, isolated_edges]), make_directed=True) - - ccs = flatgraph_utils.connected_components(graph) - - new_parent_ids = cg.get_unique_node_id_range(chunk_id, len(ccs)) - - for i_cc, cc in enumerate(ccs): - new_parent_id = new_parent_ids[i_cc] - cc_node_ids = unique_graph_ids[cc] - - lvl2_dict[new_parent_id] = [lvl2_node_id] - - # Write changes to atomic nodes and new lvl2 parent row - val_dict = {column_keys.Hierarchy.Child: cc_node_ids} - rows.append(cg.mutate_row( - serializers.serialize_uint64(new_parent_id), - val_dict, time_stamp=time_stamp)) - - for cc_node_id in cc_node_ids: - val_dict = {column_keys.Hierarchy.Parent: new_parent_id} - - rows.append(cg.mutate_row( - serializers.serialize_uint64(cc_node_id), - val_dict, time_stamp=time_stamp)) - - # Cross edges --- - cross_edge_m = np.in1d(cross_edges[:, 0], cc_node_ids) - cc_cross_edges = cross_edges[cross_edge_m] - cc_cross_edge_layers = cross_edge_layers[cross_edge_m] - u_cc_cross_edge_layers = np.unique(cc_cross_edge_layers) - - lvl2_cross_chunk_edge_dict[new_parent_id] = {} - - for l in range(2, cg.n_layers): - empty_edges = column_keys.Connectivity.CrossChunkEdge.deserialize(b'') - lvl2_cross_chunk_edge_dict[new_parent_id][l] = empty_edges - - val_dict = {} - for cc_layer in u_cc_cross_edge_layers: - edge_m = cc_cross_edge_layers == cc_layer - layer_cross_edges = cc_cross_edges[edge_m] - - if len(layer_cross_edges) > 0: - val_dict[column_keys.Connectivity.CrossChunkEdge[cc_layer]] = \ - layer_cross_edges - lvl2_cross_chunk_edge_dict[new_parent_id][cc_layer] = layer_cross_edges - - if len(val_dict) > 0: - rows.append(cg.mutate_row( - serializers.serialize_uint64(new_parent_id), - val_dict, time_stamp=time_stamp)) - - if cg.n_layers == 2: - rows.extend(update_root_id_lineage(cg, new_parent_ids, - [lvl2_node_id], - operation_id=operation_id, - time_stamp=time_stamp)) - - # Write atomic nodes - rows.extend(_write_atomic_split_edges(cg, atomic_edges, - time_stamp=time_stamp)) - - # Propagate changes up the tree - if cg.n_layers > 2: - new_root_ids, new_rows = propagate_edits_to_root( - cg, lvl2_dict.copy(), lvl2_cross_chunk_edge_dict, - operation_id=operation_id, time_stamp=time_stamp) - rows.extend(new_rows) - else: - new_root_ids = np.array(list(lvl2_dict.keys())) - - return new_root_ids, list(lvl2_dict.keys()), rows - - -def old_parent_childrens(eh, node_ids, layer): - """ Retrieves the former partners of new nodes - - Two steps - 1. acquire old parents - 2. 
read children of those old parents - - :param eh: EditHelper instance - :param node_ids: list of np.uint64s - :param layer: np.int - :return: - """ - assert len(node_ids) > 0 - assert np.sum(np.in1d(node_ids, eh.new_node_ids)) == len(node_ids) - - # 1 - gather all next layer parents - old_next_layer_node_ids = [] - old_this_layer_node_ids = [] - for node_id in node_ids: - old_next_layer_node_ids.extend( - eh.get_old_node_ids(node_id, layer + 1)) - - old_this_layer_node_ids.extend( - eh.get_old_node_ids(node_id, layer)) - - old_next_layer_node_ids = np.unique(old_next_layer_node_ids) - next_layer_m = eh.cg.get_chunk_layers(old_next_layer_node_ids) == layer + 1 - old_next_layer_node_ids = old_next_layer_node_ids[next_layer_m] - - old_this_layer_node_ids = np.unique(old_this_layer_node_ids) - this_layer_m = eh.cg.get_chunk_layers(old_this_layer_node_ids) == layer - old_this_layer_node_ids = old_this_layer_node_ids[this_layer_m] - - # 2 - acquire their children - old_this_layer_partner_ids = [] - for old_next_layer_node_id in old_next_layer_node_ids: - partner_ids = eh.get_layer_children(old_next_layer_node_id, layer, - layer_only=True) - - partner_ids = partner_ids[~np.in1d(partner_ids, - old_this_layer_node_ids)] - old_this_layer_partner_ids.extend(partner_ids) - - old_this_layer_partner_ids = np.unique(old_this_layer_partner_ids) - - return old_this_layer_node_ids, old_next_layer_node_ids, \ - old_this_layer_partner_ids - - -def compute_cross_chunk_connected_components(eh, node_ids, layer): - """ Computes connected component for next layer - - :param eh: EditHelper - :param node_ids: list of np.uint64s - :param layer: np.int - :return: - """ - assert len(node_ids) > 0 - - # On each layer we build the a graph with all cross chunk edges - # that involve the nodes on the current layer - # To do this efficiently, we acquire all candidate same layer nodes - # that were previously connected to any of the currently assessed - # nodes. 
In practice, we (1) gather all relevant parents in the next - # layer and then (2) acquire their children - - old_this_layer_node_ids, old_next_layer_node_ids, \ - old_this_layer_partner_ids = \ - old_parent_childrens(eh, node_ids, layer) - - # Build network from cross chunk edges - edge_id_map = {} - cross_edges_lvl1 = [] - for node_id in node_ids: - node_cross_edges = eh.read_cross_chunk_edges(node_id)[layer] - edge_id_map.update(dict(zip(node_cross_edges[:, 0], - [node_id] * len(node_cross_edges)))) - cross_edges_lvl1.extend(node_cross_edges) - - for old_partner_id in old_this_layer_partner_ids: - node_cross_edges = eh.read_cross_chunk_edges(old_partner_id)[layer] - - edge_id_map.update(dict(zip(node_cross_edges[:, 0], - [old_partner_id] * len(node_cross_edges)))) - cross_edges_lvl1.extend(node_cross_edges) - - cross_edges_lvl1 = np.array(cross_edges_lvl1) - edge_id_map_vec = np.vectorize(edge_id_map.get) - - if len(cross_edges_lvl1) > 0: - cross_edges = edge_id_map_vec(cross_edges_lvl1) - else: - cross_edges = np.empty([0, 2], dtype=np.uint64) - - assert np.sum(np.in1d(eh.old_node_ids, cross_edges)) == 0 - - cross_edges = np.concatenate([cross_edges, - np.vstack([node_ids, node_ids]).T]) - - graph, _, _, unique_graph_ids = flatgraph_utils.build_gt_graph( - cross_edges, make_directed=True) - - ccs = flatgraph_utils.connected_components(graph) - - return ccs, unique_graph_ids - - -def update_root_id_lineage(cg, new_root_ids, former_root_ids, operation_id, - time_stamp): - assert len(former_root_ids) < 2 or len(new_root_ids) < 2 - - rows = [] - - for new_root_id in new_root_ids: - val_dict = {column_keys.Hierarchy.FormerParent: np.array(former_root_ids), - column_keys.OperationLogs.OperationID: operation_id} - rows.append(cg.mutate_row(serializers.serialize_uint64(new_root_id), - val_dict, time_stamp=time_stamp)) - - for former_root_id in former_root_ids: - val_dict = {column_keys.Hierarchy.NewParent: np.array(new_root_ids), - column_keys.OperationLogs.OperationID: operation_id} - rows.append(cg.mutate_row(serializers.serialize_uint64(former_root_id), - val_dict, time_stamp=time_stamp)) - - return rows - -def create_parent_children_rows(cg, parent_id, children_ids, - parent_cross_chunk_edge_dict, time_stamp): - """ Generates BigTable rows - - :param eh: EditHelper - :param parent_id: np.uint64 - :param children_ids: list of np.uint64s - :param parent_cross_chunk_edge_dict: dict - :param former_root_ids: list of np.uint64s - :param operation_id: np.uint64 - :param time_stamp: datetime.datetime - :return: - """ - - rows = [] - - val_dict = {} - for l, layer_edges in parent_cross_chunk_edge_dict.items(): - val_dict[column_keys.Connectivity.CrossChunkEdge[l]] = layer_edges - - assert np.max(cg.get_chunk_layers(children_ids)) < cg.get_chunk_layer( - parent_id) - - val_dict[column_keys.Hierarchy.Child] = children_ids - - rows.append(cg.mutate_row(serializers.serialize_uint64(parent_id), - val_dict, time_stamp=time_stamp)) - - for child_id in children_ids: - val_dict = {column_keys.Hierarchy.Parent: parent_id} - rows.append(cg.mutate_row(serializers.serialize_uint64(child_id), - val_dict, time_stamp=time_stamp)) - - return rows - -def propagate_edits_to_root(cg, - lvl2_dict: Dict, - lvl2_cross_chunk_edge_dict: Dict, - operation_id: np.uint64, - time_stamp: datetime.datetime): - """ Propagates changes through layers - - :param cg: ChunkedGraph instance - :param lvl2_dict: dict - maps new ids to old ids - :param lvl2_cross_chunk_edge_dict: dict - :param operation_id: np.uint64 - :param time_stamp: 
datetime.datetime - :return: - """ - rows = [] - - # Initialization - eh = EditHelper(cg, lvl2_dict, lvl2_cross_chunk_edge_dict) - eh.bulk_family_read() - - # Setup loop variables - layer_dict = collections.defaultdict(list) - layer_dict[2] = list(lvl2_dict.keys()) - new_root_ids = [] - # Loop over all layers up to the top - there might be layers where there is - # nothing to do - for current_layer in range(2, eh.cg.n_layers): - if len(layer_dict[current_layer]) == 0: - continue - - new_node_ids = layer_dict[current_layer] - - # Calculate connected components based on cross chunk edges ------------ - ccs, unique_graph_ids = \ - compute_cross_chunk_connected_components(eh, new_node_ids, - current_layer) - - # Build a dictionary of new connected components ----------------------- - cc_collections = collections.defaultdict(list) - for cc in ccs: - cc_node_ids = unique_graph_ids[cc] - cc_cross_edge_dict = collections.defaultdict(list) - for cc_node_id in cc_node_ids: - node_cross_edges = eh.read_cross_chunk_edges(cc_node_id) - cc_cross_edge_dict = \ - combine_cross_chunk_edge_dicts(cc_cross_edge_dict, - node_cross_edges, - start_layer=current_layer + 1) - - if (not current_layer + 1 in cc_cross_edge_dict or - len(cc_cross_edge_dict[current_layer + 1]) == 0) and \ - len(cc_node_ids) == 1: - # Skip connection - next_layer = None - for l in range(current_layer + 1, eh.cg.n_layers): - if len(cc_cross_edge_dict[l]) > 0: - next_layer = l - break - - if next_layer is None: - next_layer = eh.cg.n_layers - else: - next_layer = current_layer + 1 - - next_layer_chunk_id = eh.cg.get_parent_chunk_id_dict(cc_node_ids[0])[next_layer] - - cc_collections[next_layer_chunk_id].append( - [cc_node_ids, cc_cross_edge_dict]) - - # At this point we extracted all relevant data - now we just need to - # create the new rows -------------------------------------------------- - for next_layer_chunk_id in cc_collections: - n_ids = len(cc_collections[next_layer_chunk_id]) - new_parent_ids = eh.cg.get_unique_node_id_range(next_layer_chunk_id, - n_ids) - next_layer = eh.cg.get_chunk_layer(next_layer_chunk_id) - - for new_parent_id, cc_collection in \ - zip(new_parent_ids, cc_collections[next_layer_chunk_id]): - layer_dict[next_layer].append(new_parent_id) - eh.add_new_layer_node(new_parent_id, cc_collection[0], - cc_collection[1]) - - cc_rows = create_parent_children_rows(eh.cg, new_parent_id, - cc_collection[0], - cc_collection[1], - time_stamp) - rows.extend(cc_rows) - - if eh.cg.get_chunk_layer(next_layer_chunk_id) == eh.cg.n_layers: - new_root_ids.extend(new_parent_ids) - former_root_ids = [] - for new_parent_id in new_parent_ids: - former_root_ids.extend( - eh.get_old_node_ids(new_parent_id, eh.cg.n_layers)) - - former_root_ids = np.unique(former_root_ids) - - rl_rows = update_root_id_lineage(cg, new_parent_ids, - former_root_ids, - operation_id=operation_id, - time_stamp=time_stamp) - rows.extend(rl_rows) - - return new_root_ids, rows - - -class EditHelper(object): - def __init__(self, cg, lvl2_dict, cross_chunk_edge_dict): - """ - - :param cg: ChunkedGraph isntance - :param lvl2_dict: maps new lvl2 ids to old lvl2 ids - """ - self._cg = cg - self._lvl2_dict = lvl2_dict - - self._parent_dict = {} - self._children_dict = {} - self._cross_chunk_edge_dict = cross_chunk_edge_dict - self._new_node_ids = list(lvl2_dict.keys()) - self._old_node_dict = lvl2_dict - - @property - def cg(self): - return self._cg - - @property - def lvl2_dict(self): - return self._lvl2_dict - - @property - def old_node_dict(self): - return 
self._old_node_dict - - @property - def old_node_ids(self): - return np.concatenate(list(self.old_node_dict.values())) - - @property - def new_node_ids(self): - return self._new_node_ids - - def get_children(self, node_id): - """ Cache around the get_children call to the chunkedgraph - - :param node_id: np.uint64 - :return: np.uint64 - """ - if not node_id in self._children_dict: - self._children_dict[node_id] = self.cg.get_children(node_id) - for child_id in self._children_dict[node_id]: - if not child_id in self._parent_dict: - self._parent_dict[child_id] = node_id - else: - assert self._parent_dict[child_id] == node_id - - return self._children_dict[node_id] - - def get_parent(self, node_id): - """ Cache around the get_parent call to the chunkedgraph - - :param node_id: np.uint64 - :return: np.uint64 - """ - if not node_id in self._parent_dict: - self._parent_dict[node_id] = self.cg.get_parent(node_id) - - return self._parent_dict[node_id] - - def get_root(self, node_id, get_all_parents=False): - parents = [node_id] - - while self.get_parent(parents[-1]) is not None: - parents.append(self.get_parent(parents[-1])) - - if get_all_parents: - return np.array(parents) - else: - return parents[-1] - - def get_layer_children(self, node_id, layer, layer_only=False): - """ Get - - :param node_id: - :param layer: - :param layer_only: - :return: - """ - assert layer > 0 - assert layer <= self.cg.get_chunk_layer(node_id) - - if self.cg.get_chunk_layer(node_id) == layer: - return [node_id] - - layer_children_ids = [] - next_children_ids = [node_id] - - while len(next_children_ids) > 0: - next_children_id = next_children_ids[0] - del next_children_ids[0] - - children_ids = self.get_children(next_children_id) - child_id = children_ids[0] - - if self.cg.get_chunk_layer(child_id) > layer: - next_children_ids.extend(children_ids) - elif self.cg.get_chunk_layer(child_id) == layer: - layer_children_ids.extend(children_ids) - elif self.cg.get_chunk_layer(child_id) < layer and not layer_only: - layer_children_ids.extend(children_ids) - - return np.array(layer_children_ids, dtype=np.uint64) - - def get_layer_parent(self, node_id, layer, layer_only=False, - choose_lower_layer=False): - """ Gets parent in particular layer - - :param node_id: np.uint64 - :param layer: np.int - :param layer_only: bool - :param choose_lower_layer: bool - :return: - """ - assert layer >= self.cg.get_chunk_layer(node_id) - assert layer <= self.cg.n_layers - - if self.cg.get_chunk_layer(node_id) == layer: - return [node_id] - - layer_parent_ids = [] - next_parent_ids = [node_id] - - while len(next_parent_ids) > 0: - next_parent_id = next_parent_ids[0] - del next_parent_ids[0] - - parent_id = self.get_parent(next_parent_id) - - if parent_id is None: - raise() - - if self.cg.get_chunk_layer(parent_id) < layer: - next_parent_ids.append(parent_id) - elif self.cg.get_chunk_layer(parent_id) == layer: - layer_parent_ids.append(parent_id) - elif self.cg.get_chunk_layer(parent_id) > layer and not layer_only: - if choose_lower_layer: - layer_parent_ids.append(next_parent_id) - else: - layer_parent_ids.append(parent_id) - - return layer_parent_ids - - def _get_lower_old_node_ids(self, node_id): - if not node_id in self._new_node_ids: - return [] - elif node_id in self._old_node_dict: - return self._old_node_dict[node_id] - else: - assert self.cg.get_chunk_layer(node_id) > 1 - - old_node_ids = [] - for child_id in self.get_children(node_id): - old_node_ids.extend(self._get_lower_old_node_ids(child_id)) - - return np.unique(old_node_ids) - - def 
get_old_node_ids(self, node_id, layer): - """ Acquires old node ids for new node id - - :param node_id: np.uint64 - :param layer: np.int - :return: - """ - lower_old_node_ids = self._get_lower_old_node_ids(node_id) - - old_node_ids = [] - for lower_old_node_id in lower_old_node_ids: - old_node_ids.extend(self.get_layer_parent(lower_old_node_id, layer, - choose_lower_layer=True)) - - old_node_ids = np.unique(old_node_ids) - return old_node_ids - - def read_cross_chunk_edges(self, node_id): - """ Cache around the read_cross_chunk_edges call to the chunkedgraph - - :param node_id: np.uint64 - :return: dict - """ - if not node_id in self._cross_chunk_edge_dict: - self._cross_chunk_edge_dict[node_id] = \ - self.cg.read_cross_chunk_edges(node_id) - return self._cross_chunk_edge_dict[node_id] - - def bulk_family_read(self): - """ Caches parent and children information that will be needed later """ - def _get_root_thread(lvl2_node_id): - p_ids = self.cg.get_root(lvl2_node_id, get_all_parents=True) - p_ids = np.concatenate([[lvl2_node_id], p_ids]) - - for i_parent in range(len(p_ids) - 1): - self._parent_dict[p_ids[i_parent]] = p_ids[i_parent+1] - - def _read_cc_edges_thread(node_ids): - for node_id in node_ids: - if self.cg.get_chunk_layer(node_id) == self.cg.n_layers: - continue - - self.read_cross_chunk_edges(node_id) - - lvl2_node_ids = [] - for v in self.lvl2_dict.values(): - lvl2_node_ids.extend(v) - - mu.multithread_func(_get_root_thread, lvl2_node_ids, - n_threads=len(lvl2_node_ids), debug=False) - - parent_ids = list(self._parent_dict.values()) - child_dict = self.cg.get_children(parent_ids, flatten=False) - node_ids = [] - - for parent_id in child_dict: - self._children_dict[parent_id] = child_dict[parent_id] - - if self.cg.get_chunk_layer(parent_id) > 2: - node_ids.extend(child_dict[parent_id]) - - node_ids.append(parent_id) - - for child_id in self._children_dict[parent_id]: - if not child_id in self._parent_dict: - self._parent_dict[child_id] = parent_id - else: - assert self._parent_dict[child_id] == parent_id - - node_ids = np.unique(node_ids) - n_threads = int(np.ceil(len(node_ids) / 5)) - - if len(node_ids) > 0: - node_id_blocks = np.array_split(node_ids, n_threads) - mu.multithread_func(_read_cc_edges_thread, node_id_blocks, - n_threads=n_threads, debug=False) - - def bulk_cross_chunk_edge_read(self): - raise NotImplementedError - - def add_new_layer_node(self, node_id, children_ids, cross_chunk_edge_dict): - """ Adds a new node to the helper infrastructure - - :param node_id: np.uint64 - :param children_ids: list of np.uint64s - :param cross_chunk_edge_dict: dict - :return: - """ - self._cross_chunk_edge_dict[node_id] = cross_chunk_edge_dict - - self._children_dict[node_id] = children_ids - for child_id in children_ids: - self._parent_dict[child_id] = node_id - - self._new_node_ids.append(node_id) - layer = self.cg.get_chunk_layer(node_id) - self._old_node_dict[node_id] = self.get_old_node_ids(node_id, layer) diff --git a/pychunkedgraph/backend/chunkedgraph_utils.py b/pychunkedgraph/backend/chunkedgraph_utils.py deleted file mode 100644 index 5c2412250..000000000 --- a/pychunkedgraph/backend/chunkedgraph_utils.py +++ /dev/null @@ -1,231 +0,0 @@ -import datetime -from typing import Dict, Iterable, Optional, Union - -import numpy as np -import pandas as pd - -from google.cloud import bigtable -from google.cloud.bigtable.row_filters import TimestampRange, \ - TimestampRangeFilter, ColumnRangeFilter, RowFilterChain, \ - RowFilterUnion, RowFilter -from pychunkedgraph.backend.utils 
import column_keys, serializers - - -def compute_indices_pandas(data) -> pd.Series: - """ Computes indices of all unique entries - - Make sure to remap your array to a dense range starting at zero - - https://stackoverflow.com/questions/33281957/faster-alternative-to-numpy-where - - :param data: np.ndarray - :return: pandas dataframe - """ - d = data.ravel() - f = lambda x: np.unravel_index(x.index, data.shape) - return pd.Series(d).groupby(d).apply(f) - - -def log_n(arr, n): - """ Computes log to base n - - :param arr: array or float - :param n: int - base - :return: return log_n(arr) - """ - if n == 2: - return np.log2(arr) - elif n == 10: - return np.log10(arr) - else: - return np.log(arr) / np.log(n) - - -def compute_bitmasks(n_layers: int, fan_out: int, s_bits_atomic_layer: int = 8 - ) -> Dict[int, int]: - """ Computes the bitmasks for each layer. A bitmasks encodes how many bits - are used to store the chunk id in each dimension. The smallest number of - bits needed to encode this information is chosen. The layer id is always - encoded with 8 bits as this information is required a priori. - - Currently, encoding of layer 1 is fixed to 8 bits. - - :param n_layers: int - :param fan_out: int - :param s_bits_atomic_layer: int - :return: dict - layer -> bits for layer id - """ - - bitmask_dict = {} - for i_layer in range(n_layers, 0, -1): - layer_exp = n_layers - i_layer - n_bits_for_layers = max(1, np.ceil(log_n(fan_out**layer_exp, fan_out))) - - if i_layer == 1: - n_bits_for_layers = np.max([s_bits_atomic_layer, n_bits_for_layers]) - - n_bits_for_layers = int(n_bits_for_layers) - - # assert n_bits_for_layers <= 8 - - bitmask_dict[i_layer] = n_bits_for_layers - - # print(f"Bitmasks: {bitmask_dict}") - - return bitmask_dict - - -def get_google_compatible_time_stamp(time_stamp: datetime.datetime, - round_up: bool =False - ) -> datetime.datetime: - """ Makes a datetime.datetime time stamp compatible with googles' services. - Google restricts the accuracy of time stamps to milliseconds. Hence, the - microseconds are cut of. By default, time stamps are rounded to the lower - number. - - :param time_stamp: datetime.datetime - :param round_up: bool - :return: datetime.datetime - """ - - micro_s_gap = datetime.timedelta(microseconds=time_stamp.microsecond % 1000) - - if micro_s_gap == 0: - return time_stamp - - if round_up: - time_stamp += (datetime.timedelta(microseconds=1000) - micro_s_gap) - else: - time_stamp -= micro_s_gap - - return time_stamp - - -def get_column_filter( - columns: Union[Iterable[column_keys._Column], column_keys._Column] = None) -> RowFilter: - """ Generates a RowFilter that accepts the specified columns """ - - if isinstance(columns, column_keys._Column): - return ColumnRangeFilter(columns.family_id, - start_column=columns.key, - end_column=columns.key) - elif len(columns) == 1: - return ColumnRangeFilter(columns[0].family_id, - start_column=columns[0].key, - end_column=columns[0].key) - - return RowFilterUnion([ColumnRangeFilter(col.family_id, - start_column=col.key, - end_column=col.key) for col in columns]) - - -def get_time_range_filter( - start_time: Optional[datetime.datetime] = None, - end_time: Optional[datetime.datetime] = None, - end_inclusive: bool = True) -> RowFilter: - """ Generates a TimeStampRangeFilter which is inclusive for start and (optionally) end. 
- - :param start: - :param end: - :return: - """ - # Comply to resolution of BigTables TimeRange - if start_time is not None: - start_time = get_google_compatible_time_stamp(start_time, round_up=False) - if end_time is not None: - end_time = get_google_compatible_time_stamp(end_time, round_up=end_inclusive) - - return TimestampRangeFilter(TimestampRange(start=start_time, end=end_time)) - - -def get_time_range_and_column_filter( - columns: Optional[Union[Iterable[column_keys._Column], column_keys._Column]] = None, - start_time: Optional[datetime.datetime] = None, - end_time: Optional[datetime.datetime] = None, - end_inclusive: bool = False) -> RowFilter: - - time_filter = get_time_range_filter(start_time=start_time, - end_time=end_time, - end_inclusive=end_inclusive) - - if columns is not None: - column_filter = get_column_filter(columns) - return RowFilterChain([column_filter, time_filter]) - else: - return time_filter - - -def get_max_time(): - """ Returns the (almost) max time in datetime.datetime - - :return: datetime.datetime - """ - return datetime.datetime(9999, 12, 31, 23, 59, 59, 0) - - -def get_min_time(): - """ Returns the min time in datetime.datetime - - :return: datetime.datetime - """ - return datetime.datetime.strptime("01/01/00 00:00", "%d/%m/%y %H:%M") - - -def combine_cross_chunk_edge_dicts(d1, d2, start_layer=2): - """ Combines two cross chunk dictionaries - Cross chunk dictionaries contain a layer id -> edge list mapping. - - :param d1: dict - :param d2: dict - :param start_layer: int - :return: dict - """ - assert start_layer >= 2 - - new_d = {} - - for l in d2: - if l < start_layer: - continue - - layers = np.unique(list(d1.keys()) + list(d2.keys())) - layers = layers[layers >= start_layer] - - for l in layers: - if l in d1 and l in d2: - new_d[l] = np.concatenate([d1[l].reshape(-1, 2), - d2[l].reshape(-1, 2)]) - elif l in d1: - new_d[l] = d1[l].reshape(-1, 2) - elif l in d2: - new_d[l] = d2[l].reshape(-1, 2) - else: - raise Exception() - - edges_flattened_view = new_d[l].view(dtype='u8,u8') - m = np.unique(edges_flattened_view, return_index=True)[1] - new_d[l] = new_d[l][m] - - return new_d - - -def time_min(): - """ Returns a minimal time stamp that still works with google - - :return: datetime.datetime - """ - return datetime.datetime.strptime("01/01/00 00:00", "%d/%m/%y %H:%M") - - -def partial_row_data_to_column_dict(partial_row_data: bigtable.row_data.PartialRowData) \ - -> Dict[column_keys._Column, bigtable.row_data.PartialRowData]: - new_column_dict = {} - - for family_id, column_dict in partial_row_data._cells.items(): - for column_key, column_values in column_dict.items(): - column = column_keys.from_key(family_id, column_key) - new_column_dict[column] = column_values - - return new_column_dict diff --git a/pychunkedgraph/backend/cutting.py b/pychunkedgraph/backend/cutting.py deleted file mode 100644 index 07f1cf6d3..000000000 --- a/pychunkedgraph/backend/cutting.py +++ /dev/null @@ -1,489 +0,0 @@ -import collections -import numpy as np -import networkx as nx -import itertools -import logging -from networkx.algorithms.flow import shortest_augmenting_path, edmonds_karp, preflow_push -from networkx.algorithms.connectivity import minimum_st_edge_cut -import time -import graph_tool -import graph_tool.flow - -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union - -from pychunkedgraph.backend import flatgraph_utils -from pychunkedgraph.backend import chunkedgraph_exceptions as cg_exceptions - -float_max = np.finfo(np.float32).max 
-DEBUG_MODE = False - -def merge_cross_chunk_edges(edges: Iterable[Sequence[np.uint64]], - affs: Sequence[np.uint64], - logger: Optional[logging.Logger] = None): - """ Merges cross chunk edges - :param edges: n x 2 array of uint64s - :param affs: float array of length n - :return: - """ - # mask for edges that have to be merged - cross_chunk_edge_mask = np.isinf(affs) - - # graph with edges that have to be merged - cross_chunk_graph = nx.Graph() - cross_chunk_graph.add_edges_from(edges[cross_chunk_edge_mask]) - - # connected components in this graph will be combined in one component - ccs = nx.connected_components(cross_chunk_graph) - - # Build mapping - # For each connected component the smallest node id is chosen to be the - # representative. - remapping = {} - mapping_ks = [] - mapping_vs = [] - - for cc in ccs: - nodes = np.array(list(cc)) - rep_node = np.min(nodes) - - remapping[rep_node] = nodes - mapping_ks.extend(nodes) - mapping_vs.extend([rep_node] * len(nodes)) - - # Initialize mapping with a each node mapping to itself, then update - # those edges merged to one across chunk boundaries. - # u_nodes = np.unique(edges) - # mapping = dict(zip(u_nodes, u_nodes)) - # mapping.update(dict(zip(mapping_ks, mapping_vs))) - mapping = dict(zip(mapping_ks, mapping_vs)) - - # Vectorize remapping - mapping_vec = np.vectorize(lambda a : mapping[a] if a in mapping else a) - remapped_edges = mapping_vec(edges) - - # Remove cross chunk edges - remapped_edges = remapped_edges[~cross_chunk_edge_mask] - remapped_affs = affs[~cross_chunk_edge_mask] - - return remapped_edges, remapped_affs, mapping, remapping - - -def merge_cross_chunk_edges_graph_tool(edges: Iterable[Sequence[np.uint64]], - affs: Sequence[np.uint64], - logger: Optional[logging.Logger] = None): - """ Merges cross chunk edges - :param edges: n x 2 array of uint64s - :param affs: float array of length n - :return: - """ - - # mask for edges that have to be merged - cross_chunk_edge_mask = np.isinf(affs) - - # graph with edges that have to be merged - graph, _, _, unique_ids = flatgraph_utils.build_gt_graph( - edges[cross_chunk_edge_mask], make_directed=True) - - # connected components in this graph will be combined in one component - ccs = flatgraph_utils.connected_components(graph) - - remapping = {} - mapping = np.array([], dtype=np.uint64).reshape(-1, 2) - - for cc in ccs: - nodes = unique_ids[cc] - rep_node = np.min(nodes) - - remapping[rep_node] = nodes - - rep_nodes = np.ones(len(nodes), dtype=np.uint64).reshape(-1, 1) * rep_node - m = np.concatenate([nodes.reshape(-1, 1), rep_nodes], axis=1) - - mapping = np.concatenate([mapping, m], axis=0) - - u_nodes = np.unique(edges) - u_unmapped_nodes = u_nodes[~np.in1d(u_nodes, mapping)] - - unmapped_mapping = np.concatenate([u_unmapped_nodes.reshape(-1, 1), - u_unmapped_nodes.reshape(-1, 1)], axis=1) - mapping = np.concatenate([mapping, unmapped_mapping], axis=0) - - sort_idx = np.argsort(mapping[:, 0]) - idx = np.searchsorted(mapping[:, 0], edges, sorter=sort_idx) - remapped_edges = np.asarray(mapping[:, 1])[sort_idx][idx] - - remapped_edges = remapped_edges[~cross_chunk_edge_mask] - remapped_affs = affs[~cross_chunk_edge_mask] - - return remapped_edges, remapped_affs, mapping, remapping - - -def mincut_nx(edges: Iterable[Sequence[np.uint64]], affs: Sequence[np.uint64], - sources: Sequence[np.uint64], sinks: Sequence[np.uint64], - logger: Optional[logging.Logger] = None) -> np.ndarray: - """ Computes the min cut on a local graph - :param edges: n x 2 array of uint64s - :param affs: float 
array of length n - :param sources: uint64 - :param sinks: uint64 - :return: m x 2 array of uint64s - edges that should be removed - """ - - time_start = time.time() - - original_edges = edges.copy() - - edges, affs, mapping, remapping = merge_cross_chunk_edges(edges.copy(), - affs.copy()) - mapping_vec = np.vectorize(lambda a: mapping[a] if a in mapping else a) - - if len(edges) == 0: - return [] - - if len(mapping) > 0: - assert np.unique(list(mapping.keys()), return_counts=True)[1].max() == 1 - - remapped_sinks = mapping_vec(sinks) - remapped_sources = mapping_vec(sources) - - sinks = remapped_sinks - sources = remapped_sources - - sink_connections = np.array(list(itertools.product(sinks, sinks))) - source_connections = np.array(list(itertools.product(sources, sources))) - - weighted_graph = nx.Graph() - weighted_graph.add_edges_from(edges) - weighted_graph.add_edges_from(sink_connections) - weighted_graph.add_edges_from(source_connections) - - for i_edge, edge in enumerate(edges): - weighted_graph[edge[0]][edge[1]]['capacity'] = affs[i_edge] - weighted_graph[edge[1]][edge[0]]['capacity'] = affs[i_edge] - - # Add infinity edges for multicut - for sink_i in sinks: - for sink_j in sinks: - weighted_graph[sink_i][sink_j]['capacity'] = float_max - - for source_i in sources: - for source_j in sources: - weighted_graph[source_i][source_j]['capacity'] = float_max - - - dt = time.time() - time_start - if logger is not None: - logger.debug("Graph creation: %.2fms" % (dt * 1000)) - time_start = time.time() - - ccs = list(nx.connected_components(weighted_graph)) - - for cc in ccs: - cc_list = list(cc) - - # If connected component contains no sources and/or no sinks, - # remove its nodes from the mincut computation - if not np.any(np.in1d(sources, cc_list)) or \ - not np.any(np.in1d(sinks, cc_list)): - weighted_graph.remove_nodes_from(cc) - - r_flow = edmonds_karp(weighted_graph, sinks[0], sources[0]) - cutset = minimum_st_edge_cut(weighted_graph, sources[0], sinks[0], - residual=r_flow) - - # cutset = nx.minimum_edge_cut(weighted_graph, sources[0], sinks[0], flow_func=edmonds_karp) - - dt = time.time() - time_start - if logger is not None: - logger.debug("Mincut comp: %.2fms" % (dt * 1000)) - - if cutset is None: - return [] - - time_start = time.time() - - edge_cut = list(list(cutset)) - - weighted_graph.remove_edges_from(edge_cut) - ccs = list(nx.connected_components(weighted_graph)) - - # assert len(ccs) == 2 - - for cc in ccs: - cc_list = list(cc) - if logger is not None: - logger.debug("CC size = %d" % len(cc_list)) - - if np.any(np.in1d(sources, cc_list)): - assert np.all(np.in1d(sources, cc_list)) - assert ~np.any(np.in1d(sinks, cc_list)) - - if np.any(np.in1d(sinks, cc_list)): - assert np.all(np.in1d(sinks, cc_list)) - assert ~np.any(np.in1d(sources, cc_list)) - - dt = time.time() - time_start - if logger is not None: - logger.debug("Splitting local graph: %.2fms" % (dt * 1000)) - - remapped_cutset = [] - for cut in cutset: - if cut[0] in remapping: - pre_cut = remapping[cut[0]] - else: - pre_cut = [cut[0]] - - if cut[1] in remapping: - post_cut = remapping[cut[1]] - else: - post_cut = [cut[1]] - - remapped_cutset.extend(list(itertools.product(pre_cut, post_cut))) - remapped_cutset.extend(list(itertools.product(post_cut, pre_cut))) - - remapped_cutset = np.array(remapped_cutset, dtype=np.uint64) - - remapped_cutset_flattened_view = remapped_cutset.view(dtype='u8,u8') - edges_flattened_view = original_edges.view(dtype='u8,u8') - - cutset_mask = np.in1d(remapped_cutset_flattened_view, 
edges_flattened_view) - - return remapped_cutset[cutset_mask] - - -# TODO: Refactor/break up this long function into several functions -def mincut_graph_tool(edges: Iterable[Sequence[np.uint64]], - affs: Sequence[np.uint64], - sources: Sequence[np.uint64], - sinks: Sequence[np.uint64], - logger: Optional[logging.Logger] = None) -> np.ndarray: - """ Computes the min cut on a local graph - :param edges: n x 2 array of uint64s - :param affs: float array of length n - :param sources: uint64 - :param sinks: uint64 - :return: m x 2 array of uint64s - edges that should be removed - """ - time_start = time.time() - - original_edges = edges - - # Stitch supervoxels across chunk boundaries and represent those that are - # connected with a cross chunk edge with a single id. This may cause id - # changes among sinks and sources that need to be taken care of. - edges, affs, mapping, remapping = merge_cross_chunk_edges(edges.copy(), - affs.copy()) - - dt = time.time() - time_start - if logger is not None: - logger.debug("Cross edge merging: %.2fms" % (dt * 1000)) - time_start = time.time() - - mapping_vec = np.vectorize(lambda a: mapping[a] if a in mapping else a) - - if len(edges) == 0: - return [] - - if len(mapping) > 0: - assert np.unique(list(mapping.keys()), return_counts=True)[1].max() == 1 - - remapped_sinks = mapping_vec(sinks) - remapped_sources = mapping_vec(sources) - - sinks = remapped_sinks - sources = remapped_sources - - # Assemble edges: Edges after remapping combined with edges between sinks - # and sources - sink_edges = list(itertools.product(sinks, sinks)) - source_edges = list(itertools.product(sources, sources)) - - comb_edges = np.concatenate([edges, sink_edges, source_edges]) - - comb_affs = np.concatenate([affs, [float_max, ] * - (len(sink_edges) + len(source_edges))]) - - # To make things easier for everyone involved, we map the ids to - # [0, ..., len(unique_ids) - 1] - # Generate weighted graph with graph_tool - weighted_graph, cap, gt_edges, unique_ids = \ - flatgraph_utils.build_gt_graph(comb_edges, comb_affs, - make_directed=True) - - # Create an edge property to remove edges later (will be used to test whether split valid) - is_fake_edge = np.concatenate( - [[False] * len(affs), [True] * (len(sink_edges) + len(source_edges))] - ) - remove_edges_later = np.concatenate([is_fake_edge, is_fake_edge]) - edges_to_remove = weighted_graph.new_edge_property("bool", vals=remove_edges_later) - - sink_graph_ids = np.where(np.in1d(unique_ids, sinks))[0] - source_graph_ids = np.where(np.in1d(unique_ids, sources))[0] - - if logger is not None: - logger.debug(f"{sinks}, {sink_graph_ids}") - logger.debug(f"{sources}, {source_graph_ids}") - - dt = time.time() - time_start - if logger is not None: - logger.debug("Graph creation: %.2fms" % (dt * 1000)) - time_start = time.time() - - # Get rid of connected components that are not involved in the local - # mincut - ccs = flatgraph_utils.connected_components(weighted_graph) - - removed = weighted_graph.new_vertex_property("bool") - removed.a = False - if len(ccs) > 1: - for cc in ccs: - # If connected component contains no sources or no sinks, - # remove its nodes from the mincut computation - if not (np.any(np.in1d(source_graph_ids, cc)) and \ - np.any(np.in1d(sink_graph_ids, cc))): - for node_id in cc: - removed[node_id] = True - - weighted_graph.set_vertex_filter(removed, inverted=True) - - # Somewhat untuitively, we need to create a new pruned graph for the following - # connected components call to work correctly, because the vertex 
filter - # only labels the graph and the filtered vertices still show up after running - # graph_tool.label_components - pruned_graph = graph_tool.Graph(weighted_graph, prune=True) - - # Test that there is only one connected component left - ccs = flatgraph_utils.connected_components(pruned_graph) - - if len(ccs) > 1: - logger.warning("Not all sinks and sources are within the same (local)" - "connected component") - raise cg_exceptions.PreconditionError( - "Not all sinks and sources are within the same (local)" - "connected component" - ) - elif len(ccs) == 0: - raise cg_exceptions.PreconditionError( - "Sinks and sources are not connected through the local graph. " - "Please try a different set of vertices to perform the mincut." - ) - - # Compute mincut - src, tgt = weighted_graph.vertex(source_graph_ids[0]), \ - weighted_graph.vertex(sink_graph_ids[0]) - - res = graph_tool.flow.push_relabel_max_flow(weighted_graph, src, tgt, cap) - - part = graph_tool.flow.min_st_cut(weighted_graph, src, cap, res) - - labeled_edges = part.a[gt_edges] - cut_edge_set = gt_edges[labeled_edges[:, 0] != labeled_edges[:, 1]] - - dt = time.time() - time_start - if logger is not None: - logger.debug("Mincut comp: %.2fms" % (dt * 1000)) - time_start = time.time() - - if len(cut_edge_set) == 0: - return [] - - time_start = time.time() - - if DEBUG_MODE: - # These assertions should not fail. If they do, - # then something went wrong with the graph_tool mincut computation - for i_cc in np.unique(part.a): - # Make sure to read real ids and not graph ids - cc_list = unique_ids[np.array(np.where(part.a == i_cc)[0], - dtype=np.int)] - - if np.any(np.in1d(sources, cc_list)): - assert np.all(np.in1d(sources, cc_list)) - assert ~np.any(np.in1d(sinks, cc_list)) - - if np.any(np.in1d(sinks, cc_list)): - assert np.all(np.in1d(sinks, cc_list)) - assert ~np.any(np.in1d(sources, cc_list)) - - weighted_graph.clear_filters() - - for cut_edge in cut_edge_set: - # May be more than one edge from vertex cut_edge[0] to vertex cut_edge[1], add them all - parallel_edges = weighted_graph.edge(cut_edge[0], cut_edge[1], all_edges=True) - for edge_to_remove in parallel_edges: - edges_to_remove[edge_to_remove] = True - - weighted_graph.set_filters(edges_to_remove, removed, True, True) - - ccs_test_post_cut = flatgraph_utils.connected_components(weighted_graph) - - # Make sure sinks and sources are among each other and not in different sets - # after removing the cut edges and the fake infinity edges - try: - for cc in ccs_test_post_cut: - if np.any(np.in1d(source_graph_ids, cc)): - assert np.all(np.in1d(source_graph_ids, cc)) - assert ~np.any(np.in1d(sink_graph_ids, cc)) - - if np.any(np.in1d(sink_graph_ids, cc)): - assert np.all(np.in1d(sink_graph_ids, cc)) - assert ~np.any(np.in1d(source_graph_ids, cc)) - except AssertionError: - raise cg_exceptions.PreconditionError( - "Failed to find a cut that separated the sources from the sinks. " - "Please try another cut that partitions the sets cleanly if possible. " - "If there is a clear path between all the supervoxels in each set, " - "that helps the mincut algorithm." 
- ) - - dt = time.time() - time_start - if logger is not None: - logger.debug("Verifying local graph: %.2fms" % (dt * 1000)) - - # Extract original ids - # This has potential to be optimized - remapped_cutset = [] - for s, t in flatgraph_utils.remap_ids_from_graph(cut_edge_set, unique_ids): - - if s in remapping: - s = remapping[s] - else: - s = [s] - - if t in remapping: - t = remapping[t] - else: - t = [t] - - remapped_cutset.extend(list(itertools.product(s, t))) - remapped_cutset.extend(list(itertools.product(t, s))) - - remapped_cutset = np.array(remapped_cutset, dtype=np.uint64) - - remapped_cutset_flattened_view = remapped_cutset.view(dtype='u8,u8') - edges_flattened_view = original_edges.view(dtype='u8,u8') - - cutset_mask = np.in1d(remapped_cutset_flattened_view, edges_flattened_view) - - return remapped_cutset[cutset_mask] - - -def mincut(edges: Iterable[Sequence[np.uint64]], - affs: Sequence[np.uint64], - sources: Sequence[np.uint64], - sinks: Sequence[np.uint64], - logger: Optional[logging.Logger] = None) -> np.ndarray: - """ Computes the min cut on a local graph - :param edges: n x 2 array of uint64s - :param affs: float array of length n - :param sources: uint64 - :param sinks: uint64 - :return: m x 2 array of uint64s - edges that should be removed - """ - - return mincut_graph_tool(edges=edges, affs=affs, sources=sources, - sinks=sinks, logger=logger) - diff --git a/pychunkedgraph/backend/cutting_test.py b/pychunkedgraph/backend/cutting_test.py deleted file mode 100644 index 4007872d4..000000000 --- a/pychunkedgraph/backend/cutting_test.py +++ /dev/null @@ -1,130 +0,0 @@ -import networkx as nx -import numpy as np -import graph_tool.all, graph_tool.flow, graph_tool.topology -from networkx.algorithms.flow import shortest_augmenting_path, edmonds_karp -from networkx.algorithms.connectivity import minimum_st_edge_cut -import logging -import sys -import time - -from pychunkedgraph.backend import cutting - - -def ex_graph(): - w_n = .8 - w_c = .5 - w_l = .2 - w_h = 1. 
- inf = np.finfo(np.float32).max - - edgelist = [ - [1, 2, w_n], - [1, 3, w_l], - [4, 7, w_l], - [6, 9, w_l], - [2, 4, w_l], - [2, 5, w_l], - [2, 3, w_n], - [8, 9, w_c], - [3, 5, w_l], - [3, 6, w_c], - # [4, 5, w_l], - [5, 6, w_l], - [7, 8, w_n], - [7, 10, w_n], - [8, 10, w_n], - [8, 11, w_n], - [9, 11, w_n], - [10, 12, w_n], - [11, 12, w_n], - [4, 5, inf] - ] - - edgelist = np.array(edgelist) - - edges = edgelist[:, :2].astype(np.int) - 1 - weights = edgelist[:, 2].astype(np.float) - - n_nodes = 100000 - edges = np.unique(np.sort(np.random.randint(0, n_nodes, n_nodes*5).reshape(-1, 2), axis=1), axis=0) - weights = np.random.rand(len(edges)) - - if not len(np.unique(edges) == 12): - edges, weights = ex_graph() - - edges += 100 - - return edges.astype(np.uint64), weights - - -def test_raw(): - edges, weights = ex_graph() - - weighted_graph = nx.from_edgelist(edges) - for i_edge, edge in enumerate(edges): - weighted_graph[edge[0]][edge[1]]['capacity'] = weights[i_edge] - - r_flow = edmonds_karp(weighted_graph, 0 , 11) - cutset = minimum_st_edge_cut(weighted_graph, 0, 11, residual=r_flow) - - weighted_graph.remove_edges_from(cutset) - ccs = list(nx.connected_components(weighted_graph)) - - print("NETWORKX:", cutset) - print("NETWORKX:", ccs) - - g = graph_tool.all.Graph(directed=True) - g.add_edge_list(edge_list=np.concatenate([edges, edges[:, [1, 0]]]), hashed=False) - cap = g.new_edge_property("float", vals=np.concatenate([weights, weights])) - src, tgt = g.vertex(0), g.vertex(11) - - res = graph_tool.flow.boykov_kolmogorov_max_flow(g, src, tgt, cap) - - part = graph_tool.all.min_st_cut(g, src, cap, res) - rm_edges = [e for e in g.edges() if part[e.source()] != part[e.target()]] - - print("GRAPHTOOL:", [(rm_edge.source().__str__(), rm_edge.target().__str__()) for rm_edge in rm_edges]) - - ccs = [] - for i_cc in np.unique(part.a): - ccs.append(np.where(part.a == i_cc)[0]) - - print("GRAPHTOOL:", ccs) - - return edges, weights - - -def test_imp(): - edges, weights = ex_graph() - - logger = logging.getLogger("%d" % np.random.randint(0, 100000000)) - logger.setLevel(logging.DEBUG) - sh = logging.StreamHandler(sys.stdout) - sh.setLevel(logging.DEBUG) - logger.addHandler(sh) - - # print(edges) - # print(weights) - - sources = np.unique(edges)[:10] - sinks = np.unique(edges)[-10:] - - time_start = time.time() - out_gt = cutting.mincut_graph_tool(edges, weights, sources, sinks, logger=logger) - time_gt = time.time() - time_start - - # print(out_gt) - - print("----------------") - - time_start = time.time() - out_nx = cutting.mincut_nx(edges, weights, sources, sinks, logger=logger) - time_nx = time.time() - time_start - - # print(out_nx) - - print("Time networkx: %.3fs" % (time_nx)) - print("Time graph_tool: %.3fs" % (time_gt)) - - return np.array_equal(np.unique(out_nx, axis=0), np.unique(out_gt, axis=0)) - diff --git a/pychunkedgraph/backend/flatgraph_utils.py b/pychunkedgraph/backend/flatgraph_utils.py deleted file mode 100644 index 2832db5ef..000000000 --- a/pychunkedgraph/backend/flatgraph_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -import numpy as np -import graph_tool -from graph_tool import topology - - -def build_gt_graph(edges, weights=None, is_directed=True, make_directed=False, - hashed=False): - """ Builds a graph_tool graph - - :param edges: n x 2 numpy array - :param weights: numpy array of length n - :param is_directed: bool - :param make_directed: bool - :param hashed: bool - :return: graph, capacities - """ - if weights is not None: - assert len(weights) == len(edges) - weights = 
np.array(weights) - - unique_ids, edges = np.unique(edges, return_inverse=True) - edges = edges.reshape(-1, 2) - - edges = np.array(edges) - - if make_directed: - is_directed = True - edges = np.concatenate([edges, edges[:, [1, 0]]]) - - if weights is not None: - weights = np.concatenate([weights, weights]) - - weighted_graph = graph_tool.Graph(directed=is_directed) - weighted_graph.add_edge_list(edge_list=edges, hashed=hashed) - - if weights is not None: - cap = weighted_graph.new_edge_property("float", vals=weights) - else: - cap = None - - return weighted_graph, cap, edges, unique_ids - - -def remap_ids_from_graph(graph_ids, unique_ids): - return unique_ids[graph_ids] - - -def connected_components(graph): - """ Computes connected components of graph_tool graph - - :param graph: graph_tool.Graph - :return: np.array of len == number of nodes - """ - assert isinstance(graph, graph_tool.Graph) - - cc_labels = topology.label_components(graph)[0].a - - if len(cc_labels) == 0: - return [] - - idx_sort = np.argsort(cc_labels) - vals, idx_start, count = np.unique(cc_labels[idx_sort], return_counts=True, - return_index=True) - - res = np.split(idx_sort, idx_start[1:]) - - return res \ No newline at end of file diff --git a/pychunkedgraph/backend/graphoperation.py b/pychunkedgraph/backend/graphoperation.py deleted file mode 100644 index 6537598cf..000000000 --- a/pychunkedgraph/backend/graphoperation.py +++ /dev/null @@ -1,799 +0,0 @@ -import itertools -from abc import ABC, abstractmethod -from collections import namedtuple -from datetime import datetime -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union - -import numpy as np - -from pychunkedgraph.backend import chunkedgraph_edits as cg_edits -from pychunkedgraph.backend import chunkedgraph_exceptions as cg_exceptions -from pychunkedgraph.backend.root_lock import RootLock -from pychunkedgraph.backend.utils import basetypes, column_keys, serializers - -if TYPE_CHECKING: - from pychunkedgraph.backend.chunkedgraph import ChunkedGraph - from google.cloud import bigtable - - -class GraphEditOperation(ABC): - __slots__ = ["cg", "user_id", "source_coords", "sink_coords"] - Result = namedtuple("Result", ["operation_id", "new_root_ids", "new_lvl2_ids"]) - - def __init__( - self, - cg: "ChunkedGraph", - *, - user_id: str, - source_coords: Optional[Sequence[Sequence[np.int]]] = None, - sink_coords: Optional[Sequence[Sequence[np.int]]] = None, - ) -> None: - super().__init__() - self.cg = cg - self.user_id = user_id - self.source_coords = None - self.sink_coords = None - - if source_coords is not None: - self.source_coords = np.atleast_2d(source_coords).astype(basetypes.COORDINATES) - if self.source_coords.size == 0: - self.source_coords = None - if sink_coords is not None: - self.sink_coords = np.atleast_2d(sink_coords).astype(basetypes.COORDINATES) - if self.sink_coords.size == 0: - self.sink_coords = None - - @classmethod - def _resolve_undo_chain( - cls, - cg: "ChunkedGraph", - *, - user_id: str, - operation_id: np.uint64, - is_undo: bool, - multicut_as_split: bool, - ): - log_record = cg.read_log_row(operation_id) - log_record_type = cls.get_log_record_type(log_record) - - while log_record_type in (RedoOperation, UndoOperation): - if log_record_type is RedoOperation: - operation_id = log_record[column_keys.OperationLogs.RedoOperationID] - else: - is_undo = not is_undo - operation_id = log_record[column_keys.OperationLogs.UndoOperationID] - log_record = cg.read_log_row(operation_id) - log_record_type = 
cls.get_log_record_type(log_record) - - if is_undo: - return UndoOperation( - cg, - user_id=user_id, - superseded_operation_id=operation_id, - multicut_as_split=multicut_as_split, - ) - else: - return RedoOperation( - cg, - user_id=user_id, - superseded_operation_id=operation_id, - multicut_as_split=multicut_as_split, - ) - - @staticmethod - def get_log_record_type( - log_record: Dict[column_keys._Column, Union[np.ndarray, np.number]], - *, - multicut_as_split=True, - ) -> Type["GraphEditOperation"]: - """Guesses the type of GraphEditOperation given a log record dictionary. - :param log_record: log record dictionary - :type log_record: Dict[column_keys._Column, Union[np.ndarray, np.number]] - :param multicut_as_split: If true, treat MulticutOperation as SplitOperation - - :return: The type of the matching GraphEditOperation subclass - :rtype: Type["GraphEditOperation"] - """ - if column_keys.OperationLogs.UndoOperationID in log_record: - return UndoOperation - if column_keys.OperationLogs.RedoOperationID in log_record: - return RedoOperation - if column_keys.OperationLogs.AddedEdge in log_record: - return MergeOperation - if column_keys.OperationLogs.RemovedEdge in log_record: - if multicut_as_split or column_keys.OperationLogs.BoundingBoxOffset not in log_record: - return SplitOperation - return MulticutOperation - raise TypeError(f"Could not determine graph operation type.") - - @classmethod - def from_log_record( - cls, - cg: "ChunkedGraph", - log_record: Dict[column_keys._Column, Union[np.ndarray, np.number]], - *, - multicut_as_split: bool = True, - ) -> "GraphEditOperation": - """Generates the correct GraphEditOperation given a log record dictionary. - :param cg: The ChunkedGraph instance - :type cg: "ChunkedGraph" - :param log_record: log record dictionary - :type log_record: Dict[column_keys._Column, Union[np.ndarray, np.number]] - :param multicut_as_split: If true, don't recalculate MultiCutOperation, just - use the resulting removed edges and generate SplitOperation instead (faster). 
- :type multicut_as_split: bool - - :return: The matching GraphEditOperation subclass - :rtype: "GraphEditOperation" - """ - - def _optional(column): - try: - return log_record[column] - except KeyError: - return None - - log_record_type = cls.get_log_record_type(log_record, multicut_as_split=multicut_as_split) - user_id = log_record[column_keys.OperationLogs.UserID] - - if log_record_type is UndoOperation: - superseded_operation_id = log_record[column_keys.OperationLogs.UndoOperationID] - return cls.undo_operation( - cg, - user_id=user_id, - operation_id=superseded_operation_id, - multicut_as_split=multicut_as_split, - ) - - if log_record_type is RedoOperation: - superseded_operation_id = log_record[column_keys.OperationLogs.RedoOperationID] - return cls.redo_operation( - cg, - user_id=user_id, - operation_id=superseded_operation_id, - multicut_as_split=multicut_as_split, - ) - - source_coords = _optional(column_keys.OperationLogs.SourceCoordinate) - sink_coords = _optional(column_keys.OperationLogs.SinkCoordinate) - - if log_record_type is MergeOperation: - added_edges = log_record[column_keys.OperationLogs.AddedEdge] - affinities = _optional(column_keys.OperationLogs.Affinity) - return MergeOperation( - cg, - user_id=user_id, - source_coords=source_coords, - sink_coords=sink_coords, - added_edges=added_edges, - affinities=affinities, - ) - - if log_record_type is SplitOperation: - removed_edges = log_record[column_keys.OperationLogs.RemovedEdge] - return SplitOperation( - cg, - user_id=user_id, - source_coords=source_coords, - sink_coords=sink_coords, - removed_edges=removed_edges, - ) - - if log_record_type is MulticutOperation: - bbox_offset = log_record[column_keys.OperationLogs.BoundingBoxOffset] - source_ids = log_record[column_keys.OperationLogs.SourceID] - sink_ids = log_record[column_keys.OperationLogs.SinkID] - return MulticutOperation( - cg, - user_id=user_id, - source_coords=source_coords, - sink_coords=sink_coords, - bbox_offset=bbox_offset, - source_ids=source_ids, - sink_ids=sink_ids, - ) - - raise TypeError(f"Could not determine graph operation type.") - - @classmethod - def from_operation_id( - cls, cg: "ChunkedGraph", operation_id: np.uint64, *, multicut_as_split: bool = True - ): - """Generates the correct GraphEditOperation given a operation ID. - :param cg: The ChunkedGraph instance - :type cg: "ChunkedGraph" - :param operation_id: The operation ID - :type operation_id: np.uint64 - :param multicut_as_split: If true, don't recalculate MultiCutOperation, just - use the resulting removed edges and generate SplitOperation instead (faster). - :type multicut_as_split: bool - - :return: The matching GraphEditOperation subclass - :rtype: "GraphEditOperation" - """ - log_record = cg.read_log_row(operation_id) - return cls.from_log_record(cg, log_record, multicut_as_split=multicut_as_split) - - @classmethod - def undo_operation( - cls, - cg: "ChunkedGraph", - *, - user_id: str, - operation_id: np.uint64, - multicut_as_split: bool = True, - ) -> Union["UndoOperation", "RedoOperation"]: - """Create a GraphEditOperation that, if executed, would undo the changes introduced by - operation_id. 
- - NOTE: If operation_id is an UndoOperation, this function might return an instance of - RedoOperation instead (depending on how the Undo/Redo chain unrolls) - - :param cg: The ChunkedGraph instance - :type cg: "ChunkedGraph" - :param user_id: User that should be associated with this undo operation - :type user_id: str - :param operation_id: The operation ID to be undone - :type operation_id: np.uint64 - :param multicut_as_split: If true, don't recalculate MultiCutOperation, just - use the resulting removed edges and generate SplitOperation instead (faster). - :type multicut_as_split: bool - - :return: A GraphEditOperation that, if executed, will undo the change introduced by - operation_id. - :rtype: Union["UndoOperation", "RedoOperation"] - """ - return cls._resolve_undo_chain( - cg, - user_id=user_id, - operation_id=operation_id, - is_undo=True, - multicut_as_split=multicut_as_split, - ) - - @classmethod - def redo_operation( - cls, cg: "ChunkedGraph", *, user_id: str, operation_id: np.uint64, multicut_as_split=True - ) -> Union["UndoOperation", "RedoOperation"]: - """Create a GraphEditOperation that, if executed, would redo the changes introduced by - operation_id. - - NOTE: If operation_id is an UndoOperation, this function might return an instance of - UndoOperation instead (depending on how the Undo/Redo chain unrolls) - - :param cg: The ChunkedGraph instance - :type cg: "ChunkedGraph" - :param user_id: User that should be associated with this redo operation - :type user_id: str - :param operation_id: The operation ID to be redone - :type operation_id: np.uint64 - :param multicut_as_split: If true, don't recalculate MultiCutOperation, just - use the resulting removed edges and generate SplitOperation instead (faster). - :type multicut_as_split: bool - - :return: A GraphEditOperation that, if executed, will redo the changes introduced by - operation_id. - :rtype: Union["UndoOperation", "RedoOperation"] - """ - return cls._resolve_undo_chain( - cg, - user_id=user_id, - operation_id=operation_id, - is_undo=False, - multicut_as_split=multicut_as_split, - ) - - @abstractmethod - def _update_root_ids(self) -> np.ndarray: - """Retrieves and validates the most recent root IDs affected by this GraphEditOperation. - :return: New most recent root IDs - :rtype: np.ndarray - """ - - @abstractmethod - def _apply( - self, *, operation_id, timestamp - ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: - """Initiates the graph operation calculation. 
- :return: New root IDs, new Lvl2 node IDs, and affected Bigtable rows - :rtype: Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]] - """ - - @abstractmethod - def _create_log_record(self, *, operation_id, timestamp, new_root_ids) -> "bigtable.row.Row": - """Creates a log record with all necessary information to replay the current - GraphEditOperation - :return: Bigtable row containing the log record - :rtype: bigtable.row.Row - """ - - @abstractmethod - def invert(self) -> "GraphEditOperation": - """Creates a GraphEditOperation that would cancel out changes introduced by the current - GraphEditOperation - :return: The inverse of GraphEditOperation - :rtype: GraphEditOperation - """ - - def execute(self) -> "GraphEditOperation.Result": - """Executes current GraphEditOperation: - * Calls the subclass's _update_root_ids method - * Locks root IDs - * Calls the subclass's _apply method - * Calls the subclass's _create_log_record method - * Writes all new rows to Bigtable - * Releases root ID lock - :return: Result of successful graph operation - :rtype: GraphEditOperation.Result - """ - root_ids = self._update_root_ids() - - with RootLock(self.cg, root_ids) as root_lock: - lock_operation_ids = np.array([root_lock.operation_id] * len(root_lock.locked_root_ids)) - timestamp = self.cg.read_consolidated_lock_timestamp( - root_lock.locked_root_ids, lock_operation_ids - ) - - new_root_ids, new_lvl2_ids, rows = self._apply( - operation_id=root_lock.operation_id, timestamp=timestamp - ) - - # FIXME: Remove once cg_edits.remove_edges/cg_edits.add_edges return consistent type - new_root_ids = np.array(new_root_ids, dtype=basetypes.NODE_ID) - new_lvl2_ids = np.array(new_lvl2_ids, dtype=basetypes.NODE_ID) - - # Add a row to the log - log_row = self._create_log_record( - operation_id=root_lock.operation_id, new_root_ids=new_root_ids, timestamp=timestamp - ) - - # Put log row first! - rows = [log_row] + rows - - # Execute write (makes sure that we are still owning the lock) - self.cg.bulk_write( - rows, - root_lock.locked_root_ids, - operation_id=root_lock.operation_id, - slow_retry=False, - ) - return GraphEditOperation.Result( - operation_id=root_lock.operation_id, - new_root_ids=new_root_ids, - new_lvl2_ids=new_lvl2_ids, - ) - - -class MergeOperation(GraphEditOperation): - """Merge Operation: Connect *known* pairs of supervoxels by adding a (weighted) edge. 
- - :param cg: The ChunkedGraph object - :type cg: ChunkedGraph - :param user_id: User ID that will be assigned to this operation - :type user_id: str - :param added_edges: Supervoxel IDs of all added edges [[source, sink]] - :type added_edges: Sequence[Sequence[np.uint64]] - :param source_coords: world space coordinates in nm, corresponding to IDs in added_edges[:,0], defaults to None - :type source_coords: Optional[Sequence[Sequence[np.int]]], optional - :param sink_coords: world space coordinates in nm, corresponding to IDs in added_edges[:,1], defaults to None - :type sink_coords: Optional[Sequence[Sequence[np.int]]], optional - :param affinities: edge weights for newly added edges, entries corresponding to added_edges, defaults to None - :type affinities: Optional[Sequence[np.float32]], optional - """ - - __slots__ = ["added_edges", "affinities"] - - def __init__( - self, - cg: "ChunkedGraph", - *, - user_id: str, - added_edges: Sequence[Sequence[np.uint64]], - source_coords: Optional[Sequence[Sequence[np.int]]] = None, - sink_coords: Optional[Sequence[Sequence[np.int]]] = None, - affinities: Optional[Sequence[np.float32]] = None, - ) -> None: - super().__init__(cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords) - self.added_edges = np.atleast_2d(added_edges).astype(basetypes.NODE_ID) - self.affinities = None - - if affinities is not None: - self.affinities = np.atleast_1d(affinities).astype(basetypes.EDGE_AFFINITY) - if self.affinities.size == 0: - self.affinities = None - - if np.any(np.equal(self.added_edges[:, 0], self.added_edges[:, 1])): - raise cg_exceptions.PreconditionError( - f"Requested merge operation contains at least one self-loop." - ) - - for supervoxel_id in self.added_edges.ravel(): - layer = self.cg.get_chunk_layer(supervoxel_id) - if layer != 1: - raise cg_exceptions.PreconditionError( - f"Supervoxel expected, but {supervoxel_id} is a layer {layer} node." - ) - - def _update_root_ids(self) -> np.ndarray: - root_ids = np.unique(self.cg.get_roots(self.added_edges.ravel())) - return root_ids - - def _apply( - self, *, operation_id, timestamp - ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: - new_root_ids, new_lvl2_ids, rows = cg_edits.add_edges( - self.cg, - operation_id, - atomic_edges=self.added_edges, - time_stamp=timestamp, - affinities=self.affinities, - ) - return new_root_ids, new_lvl2_ids, rows - - def _create_log_record(self, *, operation_id, timestamp, new_root_ids) -> "bigtable.row.Row": - val_dict = { - column_keys.OperationLogs.UserID: self.user_id, - column_keys.OperationLogs.RootID: new_root_ids, - column_keys.OperationLogs.AddedEdge: self.added_edges, - } - if self.source_coords is not None: - val_dict[column_keys.OperationLogs.SourceCoordinate] = self.source_coords - if self.sink_coords is not None: - val_dict[column_keys.OperationLogs.SinkCoordinate] = self.sink_coords - if self.affinities is not None: - val_dict[column_keys.OperationLogs.Affinity] = self.affinities - - return self.cg.mutate_row(serializers.serialize_uint64(operation_id), val_dict, timestamp) - - def invert(self) -> "SplitOperation": - return SplitOperation( - self.cg, - user_id=self.user_id, - removed_edges=self.added_edges, - source_coords=self.source_coords, - sink_coords=self.sink_coords, - ) - - -class SplitOperation(GraphEditOperation): - """Split Operation: Cut *known* pairs of supervoxel that are directly connected by an edge. 
- - :param cg: The ChunkedGraph object - :type cg: ChunkedGraph - :param user_id: User ID that will be assigned to this operation - :type user_id: str - :param removed_edges: Supervoxel IDs of all removed edges [[source, sink]] - :type removed_edges: Sequence[Sequence[np.uint64]] - :param source_coords: world space coordinates in nm, corresponding to IDs in - removed_edges[:,0], defaults to None - :type source_coords: Optional[Sequence[Sequence[np.int]]], optional - :param sink_coords: world space coordinates in nm, corresponding to IDs in - removed_edges[:,1], defaults to None - :type sink_coords: Optional[Sequence[Sequence[np.int]]], optional - """ - - __slots__ = ["removed_edges"] - - def __init__( - self, - cg: "ChunkedGraph", - *, - user_id: str, - removed_edges: Sequence[Sequence[np.uint64]], - source_coords: Optional[Sequence[Sequence[np.int]]] = None, - sink_coords: Optional[Sequence[Sequence[np.int]]] = None, - ) -> None: - super().__init__(cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords) - self.removed_edges = np.atleast_2d(removed_edges).astype(basetypes.NODE_ID) - - if np.any(np.equal(self.removed_edges[:, 0], self.removed_edges[:, 1])): - raise cg_exceptions.PreconditionError( - f"Requested split operation contains at least one self-loop." - ) - - for supervoxel_id in self.removed_edges.ravel(): - layer = self.cg.get_chunk_layer(supervoxel_id) - if layer != 1: - raise cg_exceptions.PreconditionError( - f"Supervoxel expected, but {supervoxel_id} is a layer {layer} node." - ) - - def _update_root_ids(self) -> np.ndarray: - root_ids = np.unique(self.cg.get_roots(self.removed_edges.ravel())) - if len(root_ids) > 1: - raise cg_exceptions.PreconditionError( - f"All supervoxel must belong to the same object. Already split?" - ) - return root_ids - - def _apply( - self, *, operation_id, timestamp - ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: - new_root_ids, new_lvl2_ids, rows = cg_edits.remove_edges( - self.cg, operation_id, atomic_edges=self.removed_edges, time_stamp=timestamp - ) - return new_root_ids, new_lvl2_ids, rows - - def _create_log_record( - self, *, operation_id: np.uint64, timestamp: datetime, new_root_ids: Sequence[np.uint64] - ) -> "bigtable.row.Row": - val_dict = { - column_keys.OperationLogs.UserID: self.user_id, - column_keys.OperationLogs.RootID: new_root_ids, - column_keys.OperationLogs.RemovedEdge: self.removed_edges, - } - if self.source_coords is not None: - val_dict[column_keys.OperationLogs.SourceCoordinate] = self.source_coords - if self.sink_coords is not None: - val_dict[column_keys.OperationLogs.SinkCoordinate] = self.sink_coords - - return self.cg.mutate_row(serializers.serialize_uint64(operation_id), val_dict, timestamp) - - def invert(self) -> "MergeOperation": - return MergeOperation( - self.cg, - user_id=self.user_id, - added_edges=self.removed_edges, - source_coords=self.source_coords, - sink_coords=self.sink_coords, - ) - - -class MulticutOperation(GraphEditOperation): - """ - Multicut Operation: Apply min-cut algorithm to identify suitable edges for removal - in order to separate two groups of supervoxels. 
- - :param cg: The ChunkedGraph object - :type cg: ChunkedGraph - :param user_id: User ID that will be assigned to this operation - :type user_id: str - :param source_ids: Supervoxel IDs that should be separated from supervoxel IDs in sink_ids - :type souce_ids: Sequence[np.uint64] - :param sink_ids: Supervoxel IDs that should be separated from supervoxel IDs in source_ids - :type sink_ids: Sequence[np.uint64] - :param source_coords: world space coordinates in nm, corresponding to IDs in source_ids - :type source_coords: Sequence[Sequence[np.int]] - :param sink_coords: world space coordinates in nm, corresponding to IDs in sink_ids - :type sink_coords: Sequence[Sequence[np.int]] - :param bbox_offset: Padding for min-cut bounding box, applied to min/max coordinates - retrieved from source_coords and sink_coords, defaults to None - :type bbox_offset: Sequence[np.int] - """ - - __slots__ = ["source_ids", "sink_ids", "removed_edges", "bbox_offset"] - - def __init__( - self, - cg: "ChunkedGraph", - *, - user_id: str, - source_ids: Sequence[np.uint64], - sink_ids: Sequence[np.uint64], - source_coords: Sequence[Sequence[np.int]], - sink_coords: Sequence[Sequence[np.int]], - bbox_offset: Sequence[np.int], - ) -> None: - super().__init__(cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords) - self.removed_edges = None # Calculated from coordinates and IDs - self.source_ids = np.atleast_1d(source_ids).astype(basetypes.NODE_ID) - self.sink_ids = np.atleast_1d(sink_ids).astype(basetypes.NODE_ID) - self.bbox_offset = np.atleast_1d(bbox_offset).astype(basetypes.COORDINATES) - - if np.any(np.in1d(self.sink_ids, self.source_ids)): - raise cg_exceptions.PreconditionError( - f"One or more supervoxel exists as both, sink and source." - ) - - for supervoxel_id in itertools.chain(self.source_ids, self.sink_ids): - layer = self.cg.get_chunk_layer(supervoxel_id) - if layer != 1: - raise cg_exceptions.PreconditionError( - f"Supervoxel expected, but {supervoxel_id} is a layer {layer} node." - ) - - def _update_root_ids(self) -> np.ndarray: - sink_and_source_ids = np.concatenate((self.source_ids, self.sink_ids)) - root_ids = np.unique(self.cg.get_roots(sink_and_source_ids)) - if len(root_ids) > 1: - raise cg_exceptions.PreconditionError( - f"All supervoxel must belong to the same object. Already split?" - ) - return root_ids - - def _apply( - self, *, operation_id, timestamp - ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: - self.removed_edges = self.cg._run_multicut( - self.source_ids, self.sink_ids, self.source_coords, self.sink_coords, self.bbox_offset - ) - - if self.removed_edges.size == 0: - raise cg_exceptions.PostconditionError( - "Mincut could not find any edges to remove - weird!" 
- ) - - new_root_ids, new_lvl2_ids, rows = cg_edits.remove_edges( - self.cg, operation_id, atomic_edges=self.removed_edges, time_stamp=timestamp - ) - return new_root_ids, new_lvl2_ids, rows - - def _create_log_record( - self, *, operation_id: np.uint64, timestamp: datetime, new_root_ids: Sequence[np.uint64] - ) -> "bigtable.row.Row": - val_dict = { - column_keys.OperationLogs.UserID: self.user_id, - column_keys.OperationLogs.RootID: new_root_ids, - column_keys.OperationLogs.SourceCoordinate: self.source_coords, - column_keys.OperationLogs.SinkCoordinate: self.sink_coords, - column_keys.OperationLogs.SourceID: self.source_ids, - column_keys.OperationLogs.SinkID: self.sink_ids, - column_keys.OperationLogs.BoundingBoxOffset: self.bbox_offset, - } - return self.cg.mutate_row(serializers.serialize_uint64(operation_id), val_dict, timestamp) - - def invert(self) -> "MergeOperation": - return MergeOperation( - self.cg, - user_id=self.user_id, - added_edges=self.removed_edges, - source_coords=self.source_coords, - sink_coords=self.sink_coords, - ) - - -class RedoOperation(GraphEditOperation): - """ - RedoOperation: Used to apply a previous graph edit operation. In contrast to a - "coincidental" redo (e.g. merging an edge added by a previous merge operation), a - RedoOperation is linked to an earlier operation ID to enable its correct repetition. - Acts as counterpart to UndoOperation. - - NOTE: Avoid instantiating a RedoOperation directly, if possible. The class method - GraphEditOperation.redo_operation() is in general preferred as it will correctly - unroll Undo/Redo chains. - - :param cg: The ChunkedGraph object - :type cg: ChunkedGraph - :param user_id: User ID that will be assigned to this operation - :type user_id: str - :param superseded_operation_id: Operation ID to be redone - :type superseded_operation_id: np.uint64 - :param multicut_as_split: If true, don't recalculate MultiCutOperation, just - use the resulting removed edges and generate SplitOperation instead (faster). - :type multicut_as_split: bool - """ - - __slots__ = ["superseded_operation_id", "superseded_operation"] - - def __init__( - self, - cg: "ChunkedGraph", - *, - user_id: str, - superseded_operation_id: np.uint64, - multicut_as_split: bool, - ) -> None: - super().__init__(cg, user_id=user_id) - log_record = cg.read_log_row(superseded_operation_id) - log_record_type = GraphEditOperation.get_log_record_type(log_record) - if log_record_type in (RedoOperation, UndoOperation): - raise ValueError( - ( - f"RedoOperation received {log_record_type.__name__} as target operation, " - "which is not allowed. Use GraphEditOperation.create_redo() instead." 
- ) - ) - - self.superseded_operation_id = superseded_operation_id - self.superseded_operation = GraphEditOperation.from_log_record( - cg, log_record=log_record, multicut_as_split=multicut_as_split - ) - - def _update_root_ids(self): - return self.superseded_operation._update_root_ids() - - def _apply( - self, *, operation_id, timestamp - ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: - return self.superseded_operation._apply(operation_id=operation_id, timestamp=timestamp) - - def _create_log_record( - self, *, operation_id: np.uint64, timestamp: datetime, new_root_ids: Sequence[np.uint64] - ) -> "bigtable.row.Row": - val_dict = { - column_keys.OperationLogs.UserID: self.user_id, - column_keys.OperationLogs.RedoOperationID: self.superseded_operation_id, - column_keys.OperationLogs.RootID: new_root_ids, - } - return self.cg.mutate_row(serializers.serialize_uint64(operation_id), val_dict, timestamp) - - def invert(self) -> "GraphEditOperation": - """ - Inverts a RedoOperation. Treated as Undoing the original operation - """ - return UndoOperation( - self.cg, - user_id=self.user_id, - superseded_operation_id=self.superseded_operation_id, - multicut_as_split=False, - ) - - -class UndoOperation(GraphEditOperation): - """ - UndoOperation: Used to apply the inverse of a previous graph edit operation. In contrast - to a "coincidental" undo (e.g. merging an edge previously removed by a split operation), an - UndoOperation is linked to an earlier operation ID to enable its correct reversal. - - NOTE: Avoid instantiating an UndoOperation directly, if possible. The class method - GraphEditOperation.undo_operation() is in general preferred as it will correctly - unroll Undo/Redo chains. - - :param cg: The ChunkedGraph object - :type cg: ChunkedGraph - :param user_id: User ID that will be assigned to this operation - :type user_id: str - :param superseded_operation_id: Operation ID to be undone - :type superseded_operation_id: np.uint64 - :param multicut_as_split: If true, don't recalculate MultiCutOperation, just - use the resulting removed edges and generate SplitOperation instead (faster). - :type multicut_as_split: bool - """ - - __slots__ = ["superseded_operation_id", "inverse_superseded_operation"] - - def __init__( - self, - cg: "ChunkedGraph", - *, - user_id: str, - superseded_operation_id: np.uint64, - multicut_as_split: bool, - ) -> None: - super().__init__(cg, user_id=user_id) - log_record = cg.read_log_row(superseded_operation_id) - log_record_type = GraphEditOperation.get_log_record_type(log_record) - if log_record_type in (RedoOperation, UndoOperation): - raise ValueError( - ( - f"UndoOperation received {log_record_type.__name__} as target operation, " - "which is not allowed. Use GraphEditOperation.create_undo() instead." 
- ) - ) - - self.superseded_operation_id = superseded_operation_id - self.inverse_superseded_operation = GraphEditOperation.from_log_record( - cg, log_record=log_record, multicut_as_split=multicut_as_split - ).invert() - - def _update_root_ids(self): - return self.inverse_superseded_operation._update_root_ids() - - def _apply( - self, *, operation_id, timestamp - ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: - return self.inverse_superseded_operation._apply( - operation_id=operation_id, timestamp=timestamp - ) - - def _create_log_record( - self, *, operation_id: np.uint64, timestamp: datetime, new_root_ids: Sequence[np.uint64] - ) -> "bigtable.row.Row": - val_dict = { - column_keys.OperationLogs.UserID: self.user_id, - column_keys.OperationLogs.UndoOperationID: self.superseded_operation_id, - column_keys.OperationLogs.RootID: new_root_ids, - } - return self.cg.mutate_row(serializers.serialize_uint64(operation_id), val_dict, timestamp) - - def invert(self) -> "GraphEditOperation": - """ - Inverts an UndoOperation. Treated as Redoing the original operation - """ - return RedoOperation( - self.cg, - user_id=self.user_id, - superseded_operation_id=self.superseded_operation_id, - multicut_as_split=False, - ) diff --git a/pychunkedgraph/backend/requirements.txt b/pychunkedgraph/backend/requirements.txt deleted file mode 100644 index ba50f47cf..000000000 --- a/pychunkedgraph/backend/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -google-cloud -cloud-volume -networkx \ No newline at end of file diff --git a/pychunkedgraph/backend/root_lock.py b/pychunkedgraph/backend/root_lock.py deleted file mode 100644 index 52a2dca9d..000000000 --- a/pychunkedgraph/backend/root_lock.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import TYPE_CHECKING, Sequence, Union - -import numpy as np - -from pychunkedgraph.backend import chunkedgraph_exceptions as cg_exceptions - -if TYPE_CHECKING: - from pychunkedgraph.backend.chunkedgraph import ChunkedGraph - - -class RootLock: - """Attempts to lock the requested root IDs using a unique operation ID. - - :raises cg_exceptions.LockingError: throws when one or more root ID locks could not be - acquired. - :return: The RootLock context, including the locked root IDs and the linked operation ID - :rtype: RootLock - """ - __slots__ = ["cg", "locked_root_ids", "lock_acquired", "operation_id"] - # FIXME: `locked_root_ids` is only required and exposed because `cg.lock_root_loop` - # currently might lock different (more recent) root IDs than requested. 
- - def __init__(self, cg: "ChunkedGraph", root_ids: Union[np.uint64, Sequence[np.uint64]]) -> None: - self.cg = cg - self.locked_root_ids = np.atleast_1d(root_ids) - self.lock_acquired = False - self.operation_id = None - - def __enter__(self): - self.operation_id = self.cg.get_unique_operation_id() - self.lock_acquired, self.locked_root_ids = self.cg.lock_root_loop( - root_ids=self.locked_root_ids, operation_id=self.operation_id, max_tries=7 - ) - if not self.lock_acquired: - raise cg_exceptions.LockingError("Could not acquire root lock") - return self - - def __exit__(self, exception_type, exception_value, traceback): - if self.lock_acquired: - for locked_root_id in self.locked_root_ids: - self.cg.unlock_root(locked_root_id, self.operation_id) diff --git a/pychunkedgraph/backend/utils/column_keys.py b/pychunkedgraph/backend/utils/column_keys.py deleted file mode 100644 index 2f952d4a0..000000000 --- a/pychunkedgraph/backend/utils/column_keys.py +++ /dev/null @@ -1,255 +0,0 @@ -from typing import NamedTuple -from pychunkedgraph.backend.utils import basetypes, serializers - - -class _ColumnType(NamedTuple): - key: bytes - family_id: str - serializer: serializers._Serializer - - -class _Column(_ColumnType): - __slots__ = () - _columns = {} - - def __init__(self, **kwargs): - super().__init__() - _Column._columns[(kwargs['family_id'], kwargs['key'])] = self - - def serialize(self, obj): - return self.serializer.serialize(obj) - - def deserialize(self, stream): - return self.serializer.deserialize(stream) - - @property - def basetype(self): - return self.serializer.basetype - - -class _ColumnArray(): - _columnarrays = {} - - def __init__(self, pattern, family_id, serializer): - self._pattern = pattern - self._family_id = family_id - self._serializer = serializer - _ColumnArray._columnarrays[(family_id, pattern)] = self - - # TODO: Add missing check in `fromkey(family_id, key)` and remove this - # loop (pre-creates `_Columns`, so that the inverse lookup works) - for i in range(20): - self[i] # pylint: disable=W0104 - - def __getitem__(self, item): - return _Column(key=self.pattern % item, - family_id=self.family_id, - serializer=self._serializer) - - @property - def pattern(self): - return self._pattern - - @property - def family_id(self): - return self._family_id - - @property - def serialize(self): - return self._serializer.serialize - - @property - def deserialize(self): - return self._serializer.deserialize - - @property - def basetype(self): - return self._serializer.basetype - - -class Concurrency: - CounterID = _Column( - key=b'counter', - family_id='1', - serializer=serializers.NumPyValue(dtype=basetypes.COUNTER)) - - Lock = _Column( - key=b'lock', - family_id='0', - serializer=serializers.UInt64String()) - - -class Connectivity: - Affinity = _Column( - key=b'affinities', - family_id='0', - serializer=serializers.NumPyArray(dtype=basetypes.EDGE_AFFINITY)) - - Area = _Column( - key=b'areas', - family_id='0', - serializer=serializers.NumPyArray(dtype=basetypes.EDGE_AREA)) - - Connected = _Column( - key=b'connected', - family_id='0', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID)) - - Disconnected = _Column( - key=b'disconnected', - family_id='0', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID)) - - Partner = _Column( - key=b'atomic_partners', - family_id='0', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID)) - - CrossChunkEdge = _ColumnArray( - pattern=b'atomic_cross_edges_%d', - family_id='3', - 
serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2))) - - -class Hierarchy: - Child = _Column( - key=b'children', - family_id='0', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID)) - - FormerParent = _Column( - key=b'former_parents', - family_id='0', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID)) - - NewParent = _Column( - key=b'new_parents', - family_id='0', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID)) - - Parent = _Column( - key=b'parents', - family_id='0', - serializer=serializers.NumPyValue(dtype=basetypes.NODE_ID)) - - -class GraphSettings: - DatasetInfo = _Column( - key=b'dataset_info', - family_id='0', - serializer=serializers.JSON()) - - ChunkSize = _Column( - key=b'chunk_size', - family_id='0', - serializer=serializers.NumPyArray(dtype=basetypes.CHUNKSIZE)) - - FanOut = _Column( - key=b'fan_out', - family_id='0', - serializer=serializers.NumPyValue(dtype=basetypes.FANOUT)) - - LayerCount = _Column( - key=b'n_layers', - family_id='0', - serializer=serializers.NumPyValue(dtype=basetypes.LAYERCOUNT)) - - SegmentationPath = _Column( - key=b'cv_path', - family_id='0', - serializer=serializers.String('utf-8')) - - MeshDir = _Column( - key=b'mesh_dir', - family_id='0', - serializer=serializers.String('utf-8')) - - SpatialBits = _Column( - key=b'spatial_bits', - family_id='0', - serializer=serializers.NumPyValue(dtype=basetypes.SPATIALBITS)) - - RootCounterBits = _Column( - key=b'root_counter_bits', - family_id='0', - serializer=serializers.NumPyValue(dtype=basetypes.ROOTCOUNTERBITS)) - - SkipConnections = _Column( - key=b'skip_connections', - family_id='0', - serializer=serializers.NumPyValue(dtype=basetypes.SKIPCONNECTIONS)) - -class OperationLogs: - OperationID = _Column( - key=b'operation_id', - family_id='0', - serializer=serializers.UInt64String()) - - UndoOperationID = _Column( - key=b'undo_operation_id', - family_id='2', - serializer=serializers.UInt64String()) - - RedoOperationID = _Column( - key=b'redo_operation_id', - family_id='2', - serializer=serializers.UInt64String()) - - UserID = _Column( - key=b'user', - family_id='2', - serializer=serializers.String('utf-8')) - - RootID = _Column( - key=b'roots', - family_id='2', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID)) - - SourceID = _Column( - key=b'source_ids', - family_id='2', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID)) - - SinkID = _Column( - key=b'sink_ids', - family_id='2', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID)) - - SourceCoordinate = _Column( - key=b'source_coords', - family_id='2', - serializer=serializers.NumPyArray(dtype=basetypes.COORDINATES, shape=(-1, 3))) - - SinkCoordinate = _Column( - key=b'sink_coords', - family_id='2', - serializer=serializers.NumPyArray(dtype=basetypes.COORDINATES, shape=(-1, 3))) - - BoundingBoxOffset = _Column( - key=b'bb_offset', - family_id='2', - serializer=serializers.NumPyArray(dtype=basetypes.COORDINATES)) - - AddedEdge = _Column( - key=b'added_edges', - family_id='2', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2))) - - RemovedEdge = _Column( - key=b'removed_edges', - family_id='2', - serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2))) - - Affinity = _Column( - key=b'affinities', - family_id='2', - serializer=serializers.NumPyArray(dtype=basetypes.EDGE_AFFINITY)) - - -def from_key(family_id: str, key: bytes): - try: - return _Column._columns[(family_id, key)] - except KeyError: - # FIXME: Look if the key matches a 
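For orientation, a small sketch of how these column definitions are meant to be used; it assumes the serializers round-trip values as their class names suggest, and the example value is made up.

```
import numpy as np

from pychunkedgraph.backend.utils import column_keys

# Assuming the serializers round-trip values as their names suggest:
parent_col = column_keys.Hierarchy.Parent          # family '0', key b'parents'
cell = parent_col.serialize(np.uint64(42))         # value -> BigTable cell bytes
node_id = parent_col.deserialize(cell)             # cell bytes -> NODE_ID value

# Inverse lookup from the raw (family_id, key) pair read back from BigTable:
same_col = column_keys.from_key("0", b"parents")   # -> Hierarchy.Parent
```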
columnarray pattern and - # remove loop initialization in _ColumnArray.__init__() - raise KeyError(f"Unknown key {family_id}:{key.decode()}") diff --git a/pychunkedgraph/backend/utils/row_keys.py b/pychunkedgraph/backend/utils/row_keys.py deleted file mode 100644 index d8b83b14f..000000000 --- a/pychunkedgraph/backend/utils/row_keys.py +++ /dev/null @@ -1,2 +0,0 @@ -GraphSettings = b'params' -OperationID = b'ioperations' diff --git a/pychunkedgraph/benchmarking/datasets.py b/pychunkedgraph/benchmarking/datasets.py deleted file mode 100644 index 5be0d36a0..000000000 --- a/pychunkedgraph/benchmarking/datasets.py +++ /dev/null @@ -1,64 +0,0 @@ -import numpy as np - - -from pychunkedgraph.ingest import ran_ingestion as ri -from pychunkedgraph.benchmarking import graph_measurements as gm, timings - - -def create_benchmark_datasets(storage_path, - ws_cv_path, - cg_table_base_id, - chunk_size=[256, 256, 512], - use_skip_connections=True, - s_bits_atomic_layer=None, - fan_out=2, - aff_dtype=np.float32, - start_iter=0, - n_iter=8, - run_gm=False, - run_bm=False, - job_size=250, - instance_id=None, - project_id=None, - n_threads=[64, 64]): - - chunk_size = np.array(chunk_size) - - for i_iter in range(start_iter, n_iter): - size = chunk_size * 2 ** (i_iter + 1) - cg_table_id = f"{cg_table_base_id}_s{i_iter}" - - ri.ingest_into_chunkedgraph(storage_path=storage_path, - ws_cv_path=ws_cv_path, - cg_table_id=cg_table_id, - chunk_size=chunk_size, - use_skip_connections=use_skip_connections, - s_bits_atomic_layer=s_bits_atomic_layer, - fan_out=fan_out, - aff_dtype=aff_dtype, - size=size, - instance_id=instance_id, - project_id=project_id, - start_layer=1, - n_threads=n_threads) - - if run_gm or run_bm: - gm.run_graph_measurements(table_id=cg_table_id, n_threads=n_threads[0]) - - if run_bm: - timings.run_timings(table_id=cg_table_id, job_size=job_size) - - -def compute_graph_measurements_dataset(cg_table_base_id, start_iter=0, n_iter=8, - n_threads=1): - for i_iter in range(start_iter, n_iter): - cg_table_id = f"{cg_table_base_id}_s{i_iter}" - gm.run_graph_measurements(table_id=cg_table_id, n_threads=n_threads) - - -def compute_benchmarks_dataset(cg_table_base_id, start_iter=0, n_iter=8, - job_size=500): - for i_iter in range(start_iter, n_iter): - cg_table_id = f"{cg_table_base_id}_s{i_iter}" - timings.run_timings(table_id=cg_table_id, job_size=job_size) - diff --git a/pychunkedgraph/benchmarking/graph_measurements.py b/pychunkedgraph/benchmarking/graph_measurements.py deleted file mode 100644 index a526c5012..000000000 --- a/pychunkedgraph/benchmarking/graph_measurements.py +++ /dev/null @@ -1,375 +0,0 @@ -import numpy as np -import itertools -import time -import os -import h5py -import pandas as pd - -from pychunkedgraph.backend import chunkedgraph, chunkedgraph_comp -from pychunkedgraph.backend.utils import column_keys - -from multiwrapper import multiprocessing_utils as mu - - -HOME = os.path.expanduser("~") - - -def count_nodes_and_edges(table_id, n_threads=1): - cg = chunkedgraph.ChunkedGraph(table_id) - - bounds = np.array(cg.cv.bounds.to_list()).reshape(2, -1).T - bounds -= bounds[:, 0:1] - - chunk_id_bounds = np.ceil((bounds / cg.chunk_size[:, None])).astype(np.int) - - chunk_coord_gen = itertools.product(*[range(*r) for r in chunk_id_bounds]) - chunk_coords = np.array(list(chunk_coord_gen), dtype=np.int) - - order = np.arange(len(chunk_coords)) - np.random.shuffle(order) - - n_blocks = np.min([len(order), n_threads * 3]) - blocks = np.array_split(order, n_blocks) - - cg_serialized_info = 
cg.get_serialized_info() - if n_threads > 1: - del cg_serialized_info["credentials"] - - multi_args = [] - for block in blocks: - multi_args.append([cg_serialized_info, chunk_coords[block]]) - - if n_threads == 1: - results = mu.multiprocess_func(_count_nodes_and_edges, - multi_args, n_threads=n_threads, - verbose=False, debug=n_threads == 1) - else: - results = mu.multisubprocess_func(_count_nodes_and_edges, - multi_args, n_threads=n_threads) - - n_edges_per_chunk = [] - n_nodes_per_chunk = [] - for result in results: - n_nodes_per_chunk.extend(result[0]) - n_edges_per_chunk.extend(result[1]) - - return n_nodes_per_chunk, n_edges_per_chunk - - -def _count_nodes_and_edges(args): - serialized_cg_info, chunk_coords = args - - time_start = time.time() - - cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) - - n_edges_per_chunk = [] - n_nodes_per_chunk = [] - for chunk_coord in chunk_coords: - x, y, z = chunk_coord - rr = cg.range_read_chunk(layer=1, x=x, y=y, z=z) - - n_nodes_per_chunk.append(len(rr)) - n_edges = 0 - - for k in rr.keys(): - n_edges += len(rr[k][column_keys.Connectivity.Partner][0].value) - - n_edges_per_chunk.append(n_edges) - - print(f"{len(chunk_coords)} took {time.time() - time_start}s") - return n_nodes_per_chunk, n_edges_per_chunk - - -def count_and_download_nodes(table_id, save_dir=f"{HOME}/benchmarks/", - n_threads=1): - cg = chunkedgraph.ChunkedGraph(table_id) - - bounds = np.array(cg.cv.bounds.to_list()).reshape(2, -1).T - bounds -= bounds[:, 0:1] - - chunk_id_bounds = np.ceil((bounds / cg.chunk_size[:, None])).astype(np.int) - - chunk_coord_gen = itertools.product(*[range(*r) for r in chunk_id_bounds]) - chunk_coords = np.array(list(chunk_coord_gen), dtype=np.int) - - order = np.arange(len(chunk_coords)) - np.random.shuffle(order) - - n_blocks = np.min([len(order), n_threads * 3]) - blocks = np.array_split(order, n_blocks) - - cg_serialized_info = cg.get_serialized_info() - if n_threads > 1: - del cg_serialized_info["credentials"] - - multi_args = [] - for block in blocks: - multi_args.append([cg_serialized_info, chunk_coords[block]]) - - if n_threads == 1: - results = mu.multiprocess_func(_count_and_download_nodes, - multi_args, n_threads=n_threads, - verbose=False, debug=n_threads == 1) - else: - results = mu.multisubprocess_func(_count_and_download_nodes, - multi_args, n_threads=n_threads) - - n_nodes_per_l2_node = [] - n_l2_nodes_per_chunk = [] - n_l1_nodes_per_chunk = [] - rep_l1_nodes = [] - for result in results: - n_nodes_per_l2_node.extend(result[0]) - n_l2_nodes_per_chunk.extend(result[1]) - n_l1_nodes_per_chunk.extend(result[2]) - rep_l1_nodes.extend(result[3]) - - save_folder = f"{save_dir}/{table_id}/" - - if not os.path.exists(save_folder): - os.makedirs(save_folder) - - with h5py.File(f"{save_folder}/l1_l2_stats.h5", "w") as f: - f.create_dataset("n_nodes_per_l2_node", data=n_nodes_per_l2_node, - compression="gzip") - f.create_dataset("n_l2_nodes_per_chunk", data=n_l2_nodes_per_chunk, - compression="gzip") - f.create_dataset("n_l1_nodes_per_chunk", data=n_l1_nodes_per_chunk, - compression="gzip") - f.create_dataset("rep_l1_nodes", data=rep_l1_nodes, - compression="gzip") - - -def _count_and_download_nodes(args): - serialized_cg_info, chunk_coords = args - - time_start = time.time() - - cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) - - n_nodes_per_l2_node = [] - n_l2_nodes_per_chunk = [] - n_l1_nodes_per_chunk = [] - # l1_nodes = [] - rep_l1_nodes = [] - for chunk_coord in chunk_coords: - x, y, z = chunk_coord - rr = 
cg.range_read_chunk(layer=2, x=x, y=y, z=z, - columns=[column_keys.Hierarchy.Child]) - - n_l2_nodes_per_chunk.append(len(rr)) - n_l1_nodes = 0 - - for k in rr.keys(): - children = rr[k][column_keys.Hierarchy.Child][0].value - rep_l1_nodes.append(children[np.random.randint(0, len(children))]) - # l1_nodes.extend(children) - - n_nodes_per_l2_node.append(len(children)) - n_l1_nodes += len(children) - - n_l1_nodes_per_chunk.append(n_l1_nodes) - - print(f"{len(chunk_coords)} took {time.time() - time_start}s") - return n_nodes_per_l2_node, n_l2_nodes_per_chunk, n_l1_nodes_per_chunk, rep_l1_nodes - - -def get_root_ids_and_sv_chunks(table_id, save_dir=f"{HOME}/benchmarks/", - n_threads=1): - cg = chunkedgraph.ChunkedGraph(table_id) - - save_folder = f"{save_dir}/{table_id}/" - - if not os.path.exists(save_folder): - os.makedirs(save_folder) - - if not os.path.exists(f"{save_folder}/root_ids.h5"): - root_ids = chunkedgraph_comp.get_latest_roots(cg, n_threads=n_threads) - - with h5py.File(f"{save_folder}/root_ids.h5", "w") as f: - f.create_dataset("root_ids", data=root_ids) - else: - with h5py.File(f"{save_folder}/root_ids.h5", "r") as f: - root_ids = f["root_ids"].value - - cg_serialized_info = cg.get_serialized_info() - if n_threads > 1: - del cg_serialized_info["credentials"] - - order = np.arange(len(root_ids)) - np.random.shuffle(order) - - order = order - - n_blocks = np.min([len(order), n_threads * 3]) - blocks = np.array_split(order, n_blocks) - - multi_args = [] - for block in blocks: - multi_args.append([cg_serialized_info, root_ids[block]]) - - if n_threads == 1: - results = mu.multiprocess_func(_get_root_ids_and_sv_chunks, - multi_args, n_threads=n_threads, - verbose=False, debug=n_threads == 1) - else: - results = mu.multisubprocess_func(_get_root_ids_and_sv_chunks, - multi_args, n_threads=n_threads) - - root_ids = [] - n_l1_nodes_per_root = [] - rep_l1_nodes = [] - rep_l1_chunk_ids = [] - for result in results: - root_ids.extend(result[0]) - n_l1_nodes_per_root.extend(result[1]) - rep_l1_nodes.extend(result[2]) - rep_l1_chunk_ids.extend(result[3]) - - save_folder = f"{save_dir}/{table_id}/" - - if not os.path.exists(save_folder): - os.makedirs(save_folder) - - with h5py.File(f"{save_folder}/root_stats.h5", "w") as f: - f.create_dataset("root_ids", data=root_ids, - compression="gzip") - f.create_dataset("n_l1_nodes_per_root", data=n_l1_nodes_per_root, - compression="gzip") - f.create_dataset("rep_l1_nodes", data=rep_l1_nodes, - compression="gzip") - f.create_dataset("rep_l1_chunk_ids", data=rep_l1_chunk_ids, - compression="gzip") - - -def _get_root_ids_and_sv_chunks(args): - serialized_cg_info, root_ids = args - - time_start = time.time() - - cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) - - n_l1_nodes_per_root = [] - rep_l1_nodes = [] - rep_l1_chunk_ids = [] - for root_id in root_ids: - l1_ids = cg.get_subgraph_nodes(root_id) - - n_l1_nodes_per_root.append(len(l1_ids)) - rep_l1_node = l1_ids[np.random.randint(0, len(l1_ids))] - rep_l1_nodes.append(rep_l1_node) - rep_l1_chunk_ids.append(cg.get_chunk_coordinates(rep_l1_node)) - - print(f"{len(root_ids)} took {time.time() - time_start}s") - return root_ids, n_l1_nodes_per_root, rep_l1_nodes, rep_l1_chunk_ids - - -def get_merge_candidates(table_id, save_dir=f"{HOME}/benchmarks/", - n_threads=1): - cg = chunkedgraph.ChunkedGraph(table_id) - - bounds = np.array(cg.cv.bounds.to_list()).reshape(2, -1).T - bounds -= bounds[:, 0:1] - - chunk_id_bounds = np.ceil((bounds / cg.chunk_size[:, None])).astype(np.int) - - chunk_coord_gen = 
itertools.product(*[range(*r) for r in chunk_id_bounds]) - chunk_coords = np.array(list(chunk_coord_gen), dtype=np.int) - - order = np.arange(len(chunk_coords)) - np.random.shuffle(order) - - n_blocks = np.min([len(order), n_threads * 3]) - blocks = np.array_split(order, n_blocks) - - cg_serialized_info = cg.get_serialized_info() - if n_threads > 1: - del cg_serialized_info["credentials"] - - multi_args = [] - for block in blocks: - multi_args.append([cg_serialized_info, chunk_coords[block]]) - - if n_threads == 1: - results = mu.multiprocess_func(_get_merge_candidates, - multi_args, n_threads=n_threads, - verbose=False, debug=n_threads == 1) - else: - results = mu.multisubprocess_func(_get_merge_candidates, - multi_args, n_threads=n_threads) - merge_edges = [] - merge_edge_weights = [] - for result in results: - merge_edges.extend(result[0]) - merge_edge_weights.extend(result[1]) - - save_folder = f"{save_dir}/{table_id}/" - - if not os.path.exists(save_folder): - os.makedirs(save_folder) - - with h5py.File(f"{save_folder}/merge_edge_stats.h5", "w") as f: - f.create_dataset("merge_edges", data=merge_edges, - compression="gzip") - f.create_dataset("merge_edge_weights", data=merge_edge_weights, - compression="gzip") - - -def _get_merge_candidates(args): - serialized_cg_info, chunk_coords = args - - time_start = time.time() - - cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) - - merge_edges = [] - merge_edge_weights = [] - for chunk_coord in chunk_coords: - chunk_id = cg.get_chunk_id(layer=1, x=chunk_coord[0], - y=chunk_coord[1], z=chunk_coord[2]) - - rr = cg.range_read_chunk(chunk_id=chunk_id, - columns=[column_keys.Connectivity.Partner, - column_keys.Connectivity.Connected, - column_keys.Hierarchy.Parent]) - - ps = [] - edges = [] - for it in rr.items(): - e, _, _ = cg._retrieve_connectivity(it, connected_edges=False) - edges.extend(e) - ps.extend([it[1][column_keys.Hierarchy.Parent][0].value] * len(e)) - - if len(edges) == 0: - continue - - edges = np.sort(np.array(edges), axis=1) - cols = {"sv1": edges[:, 0], "sv2": edges[:, 1], "parent": ps} - - df = pd.DataFrame(data=cols) - dfg = df.groupby(["sv1", "sv2"]).aggregate(np.sum).reset_index() - - _, i, c = np.unique(dfg[["parent"]], return_counts=True, - return_index=True) - - merge_edges.extend(np.array(dfg.loc[i][["sv1", "sv2"]], - dtype=np.uint64)) - merge_edge_weights.extend(c) - - - print(f"{len(chunk_coords)} took {time.time() - time_start}s") - - return merge_edges, merge_edge_weights - - - -def run_graph_measurements(table_id, save_dir=f"{HOME}/benchmarks/", - n_threads=1): - get_root_ids_and_sv_chunks(table_id=table_id, save_dir=save_dir, - n_threads=n_threads) - count_and_download_nodes(table_id=table_id, save_dir=save_dir, - n_threads=n_threads) - get_merge_candidates(table_id=table_id, save_dir=save_dir, - n_threads=n_threads) - diff --git a/pychunkedgraph/benchmarking/simulator.py b/pychunkedgraph/benchmarking/simulator.py deleted file mode 100644 index 876259f32..000000000 --- a/pychunkedgraph/benchmarking/simulator.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -import numpy as np -import time -import sys -import threading - -# Hack the imports for now -sys.path.append("..") -import pychunkedgraph.chunkedgraph as chunkedgraph - -HOME = os.path.expanduser("~") - - -def measure_command(func, kwargs): - """ Measures the execution time of a function - - :param func: function - :param kwargs: dict - keyword arguments - :return: float, result - (time, result of fucntion) - """ - time_start = time.time() - r = func(**kwargs) - 
dt = time.time() - time_start - - return dt, r - - -class ChunkedGraphSimulator(object): - def __init__(self, table_id, dir=HOME + '/benchmarking/', n_clients=10): - self.cg = chunkedgraph.ChunkedGraph(table_id=table_id) - - if not dir.strip("/").endswith(table_id): - dir += "/%s/" % table_id - self._dir = dir - self._n_clients = n_clients - - if not os.path.exists(dir): - os.makedirs(dir) - - @property - def n_clients(self): - return self._n_clients - - @property - def dir(self): - return self._dir - - def read_all_rows(self, layer=None): - """ Reads all rows of the table or specific layer - - :param layer: int or None - :return: dict - """ - if layer is None: - start_layer = 1 - end_layer = 1000 - else: - start_layer = layer - end_layer = layer + 1 - - node_id_base = np.array([0, 0, 0, 0, 0, 0, 0, start_layer], - dtype=np.uint8) - node_id_base_next = node_id_base.copy() - node_id_base_next[-1] = end_layer - - start_key = chunkedgraph.serialize_node_id(np.frombuffer(node_id_base, - dtype=np.uint64)[0]) - end_key = chunkedgraph.serialize_node_id(np.frombuffer(node_id_base_next, - dtype=np.uint64)[0]) - - range_read = self.cg.table.read_rows(start_key=start_key, - end_key=end_key, - end_inclusive=False) - range_read.consume_all() - - return range_read - - def run_root_benchmark(self, n_tests=5000): - """ Measures time of get_root command - - :param n_tests: int or None - number of tests - :return: list of floats - times - """ - atomic_rows = self.read_all_rows(layer=1) - atomic_ids = list(atomic_rows.rows.keys()) - times = [] - - if n_tests is None: - n_tests = len(atomic_ids) - else: - n_tests = np.min([n_tests, len(atomic_ids)]) - - np.random.shuffle(atomic_ids) - - for atomic_id in atomic_ids[: n_tests]: - dt, _ = measure_command(self.cg.get_root, - {"atomic_id": int(atomic_id), - "is_cg_id": True}) - - times.append(dt) - - if len(times) % 100 == 0: - print("%d / %d - %.3fms +- %.3fms" % (len(times), - n_tests, - np.mean(times) * 1000, - np.std(times) * 1000)) - - np.save(self.dir + "/root_single_%d.npy" % (n_tests), times) - return times - - def run_all_leaves_benchmark(self, n_tests=10000, get_edges=False): - """ Measures time of get_root command - - :param n_tests: int or None - number of tests - :param get_edges: bool - :return: list of floats - times - """ - root_rows = self.read_all_rows(layer=6) - root_ids = list(root_rows.rows.keys()) - times = [] - - if n_tests is None: - n_tests = len(root_ids) - else: - n_tests = np.min([n_tests, len(root_ids)]) - - np.random.shuffle(root_ids) - - for root_id in root_ids[: n_tests]: - if get_edges: - dt, _ = measure_command(self.cg.get_subgraph_edges, - {"agglomeration_id": int(root_id)}) - else: - dt, _ = measure_command(self.cg.get_subgraph_nodes, - {"agglomeration_id": int(root_id)}) - - times.append(dt) - - if len(times) % 100 == 0: - print("%d / %d - %.3fms +- %.3fms" % (len(times), - n_tests, - np.mean(times) * 1000, - np.std(times) * 1000)) - - if get_edges: - np.save(self.dir + "/all_leave_edges_single_%d.npy" % (n_tests), - times) - else: - np.save(self.dir + "/all_leave_nodes_single_%d.npy" % (n_tests), - times) - return times - - - - -# class Client(threading.Thread): diff --git a/pychunkedgraph/benchmarking/timings.py b/pychunkedgraph/benchmarking/timings.py deleted file mode 100644 index d9dd603cc..000000000 --- a/pychunkedgraph/benchmarking/timings.py +++ /dev/null @@ -1,518 +0,0 @@ -import matplotlib as mpl - -try: - mpl.use('Agg') -except: - pass - -import numpy as np -import itertools -import time -import os -import h5py -from 
functools import lru_cache -import pickle as pkl -from matplotlib import pyplot as plt -import glob - -from pychunkedgraph.backend import chunkedgraph -from pychunkedgraph.backend.utils import column_keys - -from multiwrapper import multiprocessing_utils as mu - - -HOME = os.path.expanduser("~") - - -@lru_cache(maxsize=None) -def load_l1_l2_stats(save_folder): - with h5py.File(f"{save_folder}/l1_l2_stats.h5", "r") as f: - rep_l1_nodes = f["rep_l1_nodes"].value - n_nodes_per_l2_node = f["n_nodes_per_l2_node"].value - - return rep_l1_nodes, n_nodes_per_l2_node - - -@lru_cache(maxsize=None) -def load_root_stats(save_folder): - with h5py.File(f"{save_folder}/root_stats.h5", "r") as f: - root_ids = f["root_ids"].value - n_l1_nodes_per_root = f["n_l1_nodes_per_root"].value - rep_l1_chunk_ids = f["rep_l1_chunk_ids"].value - - return root_ids, n_l1_nodes_per_root, rep_l1_chunk_ids - - -@lru_cache(maxsize=None) -def load_merge_stats(save_folder): - with h5py.File(f"{save_folder}/merge_edge_stats.h5", "r") as f: - merge_edges = f["merge_edges"].value - merge_edge_weights = f["merge_edge_weights"].value - - return merge_edges, merge_edge_weights - - -def plot_scaling(re_path, key=8): - save_dir = f"{os.path.dirname(os.path.dirname(re_path))}/scaling/" - - paths = sorted(glob.glob(re_path)) - save_name = f"{os.path.basename(os.path.dirname(paths[0]))[:-3]}_{os.path.basename(paths[0]).split('.')[0]}_key{key}" - - sizes = [] - percentiles = [] - for i_path, path in enumerate(paths): - with open(path, "rb") as f: - percentiles.append(pkl.load(f)[key]["percentiles"]) - sizes.append(i_path) - - percentiles = np.array(percentiles) * 1000 - sizes = np.array(sizes) + 2 - - plt.figure(figsize=(10, 8)) - - plt.tick_params(length=8, width=1.5, labelsize=20) - plt.axes().spines['bottom'].set_linewidth(1.5) - plt.axes().spines['left'].set_linewidth(1.5) - plt.axes().spines['right'].set_linewidth(1.5) - plt.axes().spines['top'].set_linewidth(1.5) - - plt.plot(sizes, percentiles[:, 98], marker="o", linestyle="--", lw=2, c=".6", markersize=10, label="p99") - plt.plot(sizes, percentiles[:, 94], marker="o", linestyle="--", lw=2, c=".3", markersize=10, label="p95") - plt.plot(sizes, percentiles[:, 49], marker="o", linestyle="-", lw=2, c="k", markersize=10, label="median") - plt.plot(sizes, percentiles[:, 4], marker="o", linestyle="-", lw=2, c=".3", markersize=10, label="p05") - plt.plot(sizes, percentiles[:, 0], marker="o", linestyle="-", lw=2, c=".6", markersize=10, label="p01") - - plt.ylim(0, np.max(percentiles) * 1.05) - plt.xlim(1, np.max(sizes) * 1.05) - - - plt.xlabel("Number of layers", fontsize=22) - plt.ylabel("Time (ms)", fontsize=22) - - plt.legend(frameon=False, fontsize=18, loc="upper left") - - plt.tight_layout() - - plt.savefig(f"{save_dir}/{save_name}.png", dpi=300) - plt.close() - - - -def plot_timings(path): - save_dir = os.path.dirname(path) - save_name = os.path.basename(path).split(".")[0] - - with open(path, "rb") as f: - timings = pkl.load(f) - - loads = [] - percentiles = [] - for k in timings: - percentiles.append(timings[k]["percentiles"]) - loads.append(timings[k]["requests_per_s"]) - - percentiles = np.array(percentiles) * 1000 - loads = np.array(loads) - - plt.figure(figsize=(10, 8)) - - plt.tick_params(length=8, width=1.5, labelsize=20) - plt.axes().spines['bottom'].set_linewidth(1.5) - plt.axes().spines['left'].set_linewidth(1.5) - plt.axes().spines['right'].set_linewidth(1.5) - plt.axes().spines['top'].set_linewidth(1.5) - - plt.plot(loads, percentiles[:, 98], marker="o", linestyle="--", 
lw=2, c=".6", markersize=10, label="p99") - plt.plot(loads, percentiles[:, 94], marker="o", linestyle="--", lw=2, c=".3", markersize=10, label="p95") - plt.plot(loads, percentiles[:, 49], marker="o", linestyle="-", lw=2, c="k", markersize=10, label="median") - plt.plot(loads, percentiles[:, 4], marker="o", linestyle="-", lw=2, c=".3", markersize=10, label="p05") - plt.plot(loads, percentiles[:, 0], marker="o", linestyle="-", lw=2, c=".6", markersize=10, label="p01") - - plt.ylim(0, np.max(percentiles) * 1.05) - plt.xlim(0, np.max(loads) * 1.05) - - plt.xlabel("Load (requests/s)", fontsize=22) - plt.ylabel("Time (ms)", fontsize=22) - - plt.legend(frameon=False, fontsize=18, loc="upper left") - - plt.tight_layout() - - plt.savefig(f"{save_dir}/{save_name}.png", dpi=300) - plt.close() - - -def plot_all_timings(save_dir=f"{HOME}/benchmarks/"): - paths = glob.glob(f"{save_dir}/*/*.pkl") - - for path in paths: - print(path) - plot_timings(path) - - -def benchmark_root_timings(table_id, save_dir=f"{HOME}/benchmarks/", - job_size=500): - save_folder = f"{save_dir}/{table_id}/" - - n_thread_list = [1, 4, 8, 16, 24, 32, 40, 48, 64] - results = {} - - for n_threads in n_thread_list: - results[n_threads] = get_root_timings(table_id, save_dir, job_size, - n_threads=n_threads) - - print(n_threads, results[n_threads]) - - with open(f"{save_folder}/root_timings_js{job_size}.pkl", "wb") as f: - pkl.dump(results, f) - - return results - - -def get_root_timings(table_id, save_dir=f"{HOME}/benchmarks/", job_size=500, - n_threads=1): - save_folder = f"{save_dir}/{table_id}/" - - rep_l1_nodes, n_nodes_per_l2_node = load_l1_l2_stats(save_folder) - - probs = n_nodes_per_l2_node / np.sum(n_nodes_per_l2_node) - - if n_threads == 1: - n_jobs = n_threads * 3 - else: - n_jobs = n_threads * 3 - - cg = chunkedgraph.ChunkedGraph(table_id) - cg_serialized_info = cg.get_serialized_info() - if n_threads > 0: - del cg_serialized_info["credentials"] - - time_start = time.time() - np.random.seed(np.int(time.time())) - - if len(rep_l1_nodes) < job_size * 64 * 3 * 10: - replace = True - else: - replace = False - - blocks = np.random.choice(rep_l1_nodes, job_size * n_jobs, p=probs, - replace=replace).reshape(n_jobs, job_size) - - multi_args = [] - for block in blocks: - multi_args.append([cg_serialized_info, block]) - print(f"Building jobs took {time.time()-time_start}s") - - time_start = time.time() - if n_threads == 1: - results = mu.multiprocess_func(_get_root_timings, - multi_args, n_threads=n_threads, - verbose=False, debug=n_threads == 1) - else: - results = mu.multisubprocess_func(_get_root_timings, - multi_args, n_threads=n_threads) - dt = time.time() - time_start - - timings = [] - for result in results: - timings.extend(result) - - percentiles = [np.percentile(timings, k) for k in range(1, 100, 1)] - mean = np.mean(timings) - std = np.std(timings) - median = np.median(timings) - - result_dict = {"percentiles": percentiles, - "p01": percentiles[0], - "p05": percentiles[4], - "p95": percentiles[94], - "p99": percentiles[98], - "mean": mean, - "std": std, - "median": median, - "total_time_s": dt, - "job_size": job_size, - "n_jobs": n_jobs, - "n_threads": n_threads, - "replace": replace, - "requests_per_s": job_size * n_jobs / dt} - - return result_dict - - -def _get_root_timings(args): - serialized_cg_info, l1_ids = args - cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) - - timings = [] - for l1_id in l1_ids: - - time_start = time.time() - root = cg.get_root(l1_id) - dt = time.time() - time_start - timings.append(dt) 
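The result dictionaries built above index a 99-entry percentile list, so `percentiles[0]` is p01, `percentiles[49]` the median, and `percentiles[98]` p99. A toy example of that convention with made-up timings:

```
import numpy as np

timings = np.random.exponential(scale=0.05, size=1000)  # fake request times (s)

percentiles = [np.percentile(timings, k) for k in range(1, 100, 1)]  # p01 .. p99
summary = {
    "p01": percentiles[0],
    "p05": percentiles[4],
    "median": percentiles[49],
    "p95": percentiles[94],
    "p99": percentiles[98],
    "mean": np.mean(timings),
    "std": np.std(timings),
}
print({k: round(float(v), 4) for k, v in summary.items()})
```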
- - return timings - - -def benchmark_subgraph_timings(table_id, save_dir=f"{HOME}/benchmarks/", - job_size=500): - save_folder = f"{save_dir}/{table_id}/" - - n_thread_list = [1, 4, 8, 16, 24, 32, 40, 48, 64] - results = {} - - for n_threads in n_thread_list: - results[n_threads] = get_subgraph_timings(table_id, save_dir, job_size, - n_threads=n_threads) - - print(n_threads, results[n_threads]) - - with open(f"{save_folder}/subgraph_timings_js{job_size}.pkl", "wb") as f: - pkl.dump(results, f) - - return results - - -def get_subgraph_timings(table_id, save_dir=f"{HOME}/benchmarks/", job_size=500, - n_threads=1): - save_folder = f"{save_dir}/{table_id}/" - - root_ids, n_l1_nodes_per_root, rep_l1_chunk_ids = load_root_stats(save_folder) - - probs = n_l1_nodes_per_root / np.sum(n_l1_nodes_per_root) - - if n_threads == 1: - n_jobs = n_threads * 3 - else: - n_jobs = n_threads * 3 - - cg = chunkedgraph.ChunkedGraph(table_id) - cg_serialized_info = cg.get_serialized_info() - if n_threads > 0: - del cg_serialized_info["credentials"] - - time_start = time.time() - order = np.arange(len(n_l1_nodes_per_root)) - - np.random.seed(np.int(time.time())) - - if len(order) < job_size * 64 * 3 * 10: - replace = True - else: - replace = False - - blocks = np.random.choice(order, job_size * n_jobs, p=probs, - replace=replace).reshape(n_jobs, job_size) - - multi_args = [] - for block in blocks: - multi_args.append([cg_serialized_info, root_ids[block], - rep_l1_chunk_ids[block]]) - print(f"Building jobs took {time.time()-time_start}s") - - time_start = time.time() - if n_threads == 1: - results = mu.multiprocess_func(_get_subgraph_timings, - multi_args, n_threads=n_threads, - verbose=False, debug=n_threads == 1) - else: - results = mu.multisubprocess_func(_get_subgraph_timings, - multi_args, n_threads=n_threads) - dt = time.time() - time_start - - timings = [] - for result in results: - timings.extend(result) - - percentiles = [np.percentile(timings, k) for k in range(1, 100, 1)] - mean = np.mean(timings) - std = np.std(timings) - median = np.median(timings) - - result_dict = {"percentiles": percentiles, - "p01": percentiles[0], - "p05": percentiles[4], - "p95": percentiles[94], - "p99": percentiles[98], - "mean": mean, - "std": std, - "median": median, - "total_time_s": dt, - "job_size": job_size, - "n_jobs": n_jobs, - "n_threads": n_threads, - "replace": replace, - "requests_per_s": job_size * n_jobs / dt} - - return result_dict - - -def _get_subgraph_timings(args): - serialized_cg_info, root_ids, rep_l1_chunk_ids = args - cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) - - timings = [] - for root_id, rep_l1_chunk_id in zip(root_ids, rep_l1_chunk_ids): - bb = np.array([rep_l1_chunk_id, rep_l1_chunk_id + 1], dtype=np.int) - - time_start = time.time() - sv_ids = cg.get_subgraph_nodes(root_id, bb, bb_is_coordinate=False) - dt = time.time() - time_start - timings.append(dt) - - return timings - - -def benchmark_merge_split_timings(table_id, save_dir=f"{HOME}/benchmarks/", - job_size=250): - save_folder = f"{save_dir}/{table_id}/" - - n_thread_list = [1] - merge_results = {} - split_results = {} - - for n_threads in n_thread_list: - results = get_merge_split_timings(table_id, save_dir, job_size, - n_threads=n_threads) - merge_results[n_threads] = results[0] - split_results[n_threads] = results[1] - - print(n_threads, merge_results[n_threads]) - print(n_threads, split_results[n_threads]) - - with open(f"{save_folder}/merge_timings_js{job_size}.pkl", "wb") as f: - pkl.dump(merge_results, f) - - with 
open(f"{save_folder}/split_timings_js{job_size}.pkl", "wb") as f: - pkl.dump(split_results, f) - - return merge_results, split_results - - -def get_merge_split_timings(table_id, save_dir=f"{HOME}/benchmarks/", job_size=500, - n_threads=1): - save_folder = f"{save_dir}/{table_id}/" - - merge_edges, merge_edge_weights = load_merge_stats(save_folder) - - probs = merge_edge_weights / np.sum(merge_edge_weights) - - if n_threads == 1: - n_jobs = n_threads * 3 - else: - n_jobs = n_threads * 3 - - cg = chunkedgraph.ChunkedGraph(table_id) - cg_serialized_info = cg.get_serialized_info() - if n_threads > 0: - del cg_serialized_info["credentials"] - - time_start = time.time() - order = np.arange(len(merge_edges)) - - np.random.seed(np.int(time.time())) - - replace = False - - blocks = np.random.choice(order, job_size * n_jobs, p=probs, - replace=replace).reshape(n_jobs, job_size) - - multi_args = [] - for block in blocks: - multi_args.append([cg_serialized_info, merge_edges[block]]) - print(f"Building jobs took {time.time()-time_start}s") - - time_start = time.time() - if n_threads == 1: - results = mu.multiprocess_func(_get_merge_timings, - multi_args, n_threads=n_threads, - verbose=False, debug=n_threads == 1) - else: - results = mu.multisubprocess_func(_get_merge_timings, - multi_args, n_threads=n_threads) - dt = time.time() - time_start - - timings = [] - for result in results: - timings.extend(result[0]) - - percentiles = [np.percentile(timings, k) for k in range(1, 100, 1)] - mean = np.mean(timings) - std = np.std(timings) - median = np.median(timings) - - merge_results = {"percentiles": percentiles, - "p01": percentiles[0], - "p05": percentiles[4], - "p95": percentiles[94], - "p99": percentiles[98], - "mean": mean, - "std": std, - "median": median, - "total_time_s": dt, - "job_size": job_size, - "n_jobs": n_jobs, - "n_threads": n_threads, - "replace": replace, - "requests_per_s": job_size * n_jobs / dt} - - timings = [] - for result in results: - timings.extend(result[1]) - - percentiles = [np.percentile(timings, k) for k in range(1, 100, 1)] - mean = np.mean(timings) - std = np.std(timings) - median = np.median(timings) - - split_results = {"percentiles": percentiles, - "p01": percentiles[0], - "p05": percentiles[4], - "p95": percentiles[94], - "p99": percentiles[98], - "mean": mean, - "std": std, - "median": median, - "total_time_s": dt, - "job_size": job_size, - "n_jobs": n_jobs, - "n_threads": n_threads, - "replace": replace, - "requests_per_s": job_size * n_jobs / dt} - - return merge_results, split_results - - -def _get_merge_timings(args): - serialized_cg_info, merge_edges = args - cg = chunkedgraph.ChunkedGraph(**serialized_cg_info) - - merge_timings = [] - for merge_edge in merge_edges: - time_start = time.time() - root_ids = cg.add_edges(user_id="ChuckNorris", - atomic_edges=[merge_edge]).new_root_ids - dt = time.time() - time_start - merge_timings.append(dt) - - split_timings = [] - for merge_edge in merge_edges: - time_start = time.time() - root_ids = cg.remove_edges(user_id="ChuckNorris", - atomic_edges=[merge_edge], - mincut=False).new_root_ids - - dt = time.time() - time_start - split_timings.append(dt) - - return merge_timings, split_timings - - - -def run_timings(table_id, save_dir=f"{HOME}/benchmarks/", job_size=500): - benchmark_root_timings(table_id=table_id, save_dir=save_dir, - job_size=job_size) - benchmark_subgraph_timings(table_id=table_id, save_dir=save_dir, - job_size=job_size) - diff --git a/pychunkedgraph/creator/buildgraph.md b/pychunkedgraph/creator/buildgraph.md 
deleted file mode 100644 index bfefb2742..000000000 --- a/pychunkedgraph/creator/buildgraph.md +++ /dev/null @@ -1,53 +0,0 @@ -# Creating a ChunkedGraph - -There are two steps to creating a ChunkedGraph for a region graph: - -0. Creating the table and BigTable family -1. Downloading files from `cloudvolume` and storing them on disk -2. Creating the ChunkedGraph from these files - -## Creating the table and family - -Deleting the current table: - -``` -from src.pychunkedgraph import chunkedgraph -cg = chunkedgraph.ChunkedGraph(table_id="mytableid") - -cg.table.delete() -``` - -Creating a new table and family: - -``` -cg = chunkedgraph.ChunkedGraph(table_id="mytableid") - -cg.table.create() -f = cg.table.column_family(cg.family_id) -f.create() -``` - -## Downloading all files from cloudvolume - -To download all relevant friles from a cloudvolume directory do - -``` -from src.pychunkedgraph import chunkcreator - -chunkcreator.download_and_store_cv_files(cv_url) -``` -The files are stored as h5's in a directory in `home`. The directory name is chosen to be the layer name. - - -## Building the ChunkedGraph - -``` -chunkcreator.create_chunked_graph(cv_url, table_id="mytableid", nb_cpus=1) -``` - -`nb_cpus` -allows the user to run this process in parallel using `subprocesses` (see [multiprocessing.md](https://github.com/seung-lab/PyChunkedGraph/blob/master/src/pychunkedgraph/multiprocessing.md)). - - - - diff --git a/pychunkedgraph/creator/chunkcreator.py b/pychunkedgraph/creator/chunkcreator.py deleted file mode 100644 index 58a038160..000000000 --- a/pychunkedgraph/creator/chunkcreator.py +++ /dev/null @@ -1,489 +0,0 @@ -import glob -import numpy as np -import os -import re -import time -import itertools -import random - -from cloudvolume import storage - -# from chunkedgraph import ChunkedGraph -import pychunkedgraph.backend.chunkedgraph_utils -from pychunkedgraph.backend import chunkedgraph -from multiwrapper import multiprocessing_utils as mu -from pychunkedgraph.creator import creator_utils - - -def download_and_store_cv_files(dataset_name="basil", - n_threads=10, olduint32=False): - """ Downloads files from google cloud using cloud-volume - - :param dataset_name: str - :param n_threads: int - :param olduint32: bool - """ - if "basil" == dataset_name: - cv_url = "gs://nkem/basil_4k_oldnet/region_graph/" - elif "pinky40" == dataset_name: - cv_url = "gs://nkem/pinky40_v11/mst_trimmed_sem_remap/region_graph/" - elif "pinky100" == dataset_name: - cv_url = "gs://nkem/pinky100_v0/region_graph/" - else: - raise Exception("Could not identify region graph ressource") - - with storage.SimpleStorage(cv_url) as cv_st: - dir_path = creator_utils.dir_from_layer_name( - creator_utils.layer_name_from_cv_url(cv_st.layer_path)) - - if not os.path.exists(dir_path): - os.makedirs(dir_path) - - file_paths = list(cv_st.list_files()) - - file_chunks = np.array_split(file_paths, n_threads * 3) - multi_args = [] - for i_file_chunk, file_chunk in enumerate(file_chunks): - multi_args.append([i_file_chunk, cv_url, file_chunk, olduint32]) - - # Run parallelizing - if n_threads == 1: - mu.multiprocess_func(_download_and_store_cv_files_thread, - multi_args, n_threads=n_threads, - verbose=True, debug=n_threads==1) - else: - mu.multisubprocess_func(_download_and_store_cv_files_thread, - multi_args, n_threads=n_threads) - - -def _download_and_store_cv_files_thread(args): - """ Helper thread to download files from google cloud """ - chunk_id, cv_url, file_paths, olduint32 = args - - # Reset connection pool to make 
cloud-volume compatible with parallelizing - storage.reset_connection_pools() - - n_file_paths = len(file_paths) - time_start = time.time() - with storage.SimpleStorage(cv_url) as cv_st: - for i_fp, fp in enumerate(file_paths): - if i_fp % 100 == 1: - dt = time.time() - time_start - eta = dt / i_fp * n_file_paths - dt - print("%d: %d / %d - dt: %.3fs - eta: %.3fs" % ( - chunk_id, i_fp, n_file_paths, dt, eta)) - - creator_utils.download_and_store_edge_file(cv_st, fp) - - -def check_stored_cv_files(dataset_name="basil"): - """ Tests if all files were downloaded - - :param dataset_name: str - """ - if "basil" == dataset_name: - cv_url = "gs://nkem/basil_4k_oldnet/region_graph/" - elif "pinky40" == dataset_name: - cv_url = "gs://nkem/pinky40_v11/mst_trimmed_sem_remap/region_graph/" - elif "pinky100" == dataset_name: - cv_url = "gs://nkem/pinky100_v0/region_graph/" - else: - raise Exception("Could not identify region graph ressource") - - with storage.SimpleStorage(cv_url) as cv_st: - dir_path = creator_utils.dir_from_layer_name( - creator_utils.layer_name_from_cv_url(cv_st.layer_path)) - - file_paths = list(cv_st.list_files()) - - c = 0 - n_file_paths = len(file_paths) - time_start = time.time() - for i_fp, fp in enumerate(file_paths): - if i_fp % 1000 == 1: - dt = time.time() - time_start - eta = dt / i_fp * n_file_paths - dt - print("%d / %d - dt: %.3fs - eta: %.3fs" % ( - i_fp, n_file_paths, dt, eta)) - - if not os.path.exists(dir_path + fp[:-4] + ".h5"): - print(dir_path + fp[:-4] + ".h5") - c += 1 - - print("%d files were missing" % c) - - -def _sort_arrays(coords, paths): - sorting = np.lexsort((coords[..., 2], coords[..., 1], coords[..., 0])) - return coords[sorting], paths[sorting] - -def create_chunked_graph(table_id=None, cv_url=None, ws_url=None, fan_out=2, - bbox=None, chunk_size=(512, 512, 128), verbose=False, - n_threads=1): - """ Creates chunked graph from downloaded files - - :param table_id: str - :param cv_url: str - :param ws_url: str - :param fan_out: int - :param bbox: [[x_, y_, z_], [_x, _y, _z]] - :param chunk_size: tuple - :param verbose: bool - :param n_threads: int - """ - if cv_url is None or ws_url is None: - if "basil" in table_id: - cv_url = "gs://nkem/basil_4k_oldnet/region_graph/" - ws_url = "gs://neuroglancer/svenmd/basil_4k_oldnet_cg/watershed/" - elif "pinky40" in table_id: - cv_url = "gs://nkem/pinky40_v11/mst_trimmed_sem_remap/region_graph/" - ws_url = "gs://neuroglancer/svenmd/pinky40_v11/watershed/" - elif "pinky100" in table_id: - cv_url = "gs://nkem/pinky100_v0/region_graph/" - ws_url = "gs://neuroglancer/nkem/pinky100_v0/ws/lost_no-random/bbox1_0/" - else: - raise Exception("Could not identify region graph ressource") - - times = [] - time_start = time.time() - - chunk_size = np.array(list(chunk_size)) - - file_paths = np.sort(glob.glob(creator_utils.dir_from_layer_name( - creator_utils.layer_name_from_cv_url(cv_url)) + "/*")) - - file_path_blocks = np.array_split(file_paths, n_threads * 3) - - multi_args = [] - for fp_block in file_path_blocks: - multi_args.append([fp_block, table_id, chunk_size, bbox]) - - if n_threads == 1: - results = mu.multiprocess_func( - _preprocess_chunkedgraph_data_thread, multi_args, - n_threads=n_threads, - verbose=True, debug=n_threads == 1) - else: - results = mu.multisubprocess_func( - _preprocess_chunkedgraph_data_thread, multi_args, - n_threads=n_threads) - - in_chunk_connected_paths = np.array([]) - in_chunk_connected_ids = np.array([], dtype=np.uint64).reshape(-1, 3) - in_chunk_disconnected_paths = np.array([]) - 
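The same fan-out pattern recurs throughout these modules: split the work into roughly three blocks per thread, wrap each block into an argument list, and dispatch it via multiwrapper. A condensed sketch with a stand-in worker (the item array is arbitrary):

```
import numpy as np

from multiwrapper import multiprocessing_utils as mu


def _count_block(args):
    # Stand-in worker; the real workers rebuild a ChunkedGraph from its
    # serialized info and process their block of chunk coordinates or paths.
    block = args[0]
    return [len(block)]


items = np.arange(1000)                        # e.g. chunk coordinates or file paths
n_threads = 4

blocks = np.array_split(items, n_threads * 3)  # ~3 jobs per thread
multi_args = [[block] for block in blocks]

if n_threads == 1:
    results = mu.multiprocess_func(_count_block, multi_args,
                                   n_threads=n_threads, verbose=False,
                                   debug=n_threads == 1)
else:
    results = mu.multisubprocess_func(_count_block, multi_args,
                                      n_threads=n_threads)
```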
in_chunk_disconnected_ids = np.array([], dtype=np.uint64).reshape(-1, 3) - between_chunk_paths = np.array([]) - between_chunk_ids = np.array([], dtype=np.uint64).reshape(-1, 2, 3) - isolated_paths = np.array([]) - isolated_ids = np.array([], dtype=np.uint64).reshape(-1, 3) - - for result in results: - in_chunk_connected_paths = np.concatenate([in_chunk_connected_paths, result[0]]) - in_chunk_connected_ids = np.concatenate([in_chunk_connected_ids, result[1]]) - in_chunk_disconnected_paths = np.concatenate([in_chunk_disconnected_paths, result[2]]) - in_chunk_disconnected_ids = np.concatenate([in_chunk_disconnected_ids, result[3]]) - between_chunk_paths = np.concatenate([between_chunk_paths, result[4]]) - between_chunk_ids = np.concatenate([between_chunk_ids, result[5]]) - isolated_paths = np.concatenate([isolated_paths, result[6]]) - isolated_ids = np.concatenate([isolated_ids, result[7]]) - - assert len(in_chunk_connected_ids) == len(in_chunk_connected_paths) == \ - len(in_chunk_disconnected_ids) == len(in_chunk_disconnected_paths) == \ - len(isolated_ids) == len(isolated_paths) - - in_chunk_connected_ids, in_chunk_connected_paths = \ - _sort_arrays(in_chunk_connected_ids, in_chunk_connected_paths) - - in_chunk_disconnected_ids, in_chunk_disconnected_paths = \ - _sort_arrays(in_chunk_disconnected_ids, in_chunk_disconnected_paths) - - isolated_ids, isolated_paths = \ - _sort_arrays(isolated_ids, isolated_paths) - - times.append(["Preprocessing", time.time() - time_start]) - - print("Preprocessing took %.3fs = %.2fh" % (times[-1][1], times[-1][1]/3600)) - - time_start = time.time() - - multi_args = [] - - in_chunk_id_blocks = np.array_split(in_chunk_connected_ids, max(1, n_threads)) - cumsum = 0 - - for in_chunk_id_block in in_chunk_id_blocks: - multi_args.append([between_chunk_ids, between_chunk_paths, - in_chunk_id_block, cumsum]) - cumsum += len(in_chunk_id_block) - - # Run parallelizing - if n_threads == 1: - results = mu.multiprocess_func( - _between_chunk_masks_thread, multi_args, n_threads=n_threads, - verbose=True, debug=n_threads == 1) - else: - results = mu.multisubprocess_func( - _between_chunk_masks_thread, multi_args, n_threads=n_threads) - - times.append(["Data sorting", time.time() - time_start]) - - print("Data sorting took %.3fs = %.2fh" % (times[-1][1], times[-1][1]/3600)) - - time_start = time.time() - - n_layers = int(np.ceil(pychunkedgraph.backend.chunkedgraph_utils.log_n(np.max(in_chunk_connected_ids) + 1, fan_out))) + 2 - - print("N layers: %d" % n_layers) - - cg = chunkedgraph.ChunkedGraph(table_id=table_id, n_layers=np.uint64(n_layers), - fan_out=np.uint64(fan_out), - chunk_size=np.array(chunk_size, dtype=np.uint64), - cv_path=ws_url, is_new=True) - - # Fill lowest layer and create first abstraction layer - # Create arguments for parallelizing - - multi_args = [] - for result in results: - offset, between_chunk_paths_out_masked, between_chunk_paths_in_masked = result - - for i_chunk in range(len(between_chunk_paths_out_masked)): - multi_args.append([table_id, - in_chunk_connected_paths[offset + i_chunk], - in_chunk_disconnected_paths[offset + i_chunk], - isolated_paths[offset + i_chunk], - between_chunk_paths_in_masked[i_chunk], - between_chunk_paths_out_masked[i_chunk], - verbose]) - - random.shuffle(multi_args) - - print("%d jobs for creating layer 1 + 2" % len(multi_args)) - - # Run parallelizing - if n_threads == 1: - mu.multiprocess_func( - _create_atomic_layer_thread, multi_args, n_threads=n_threads, - verbose=True, debug=n_threads == 1) - else: - 
mu.multisubprocess_func( - _create_atomic_layer_thread, multi_args, n_threads=n_threads) - - times.append(["Layers 1 + 2", time.time() - time_start]) - - # Fill higher abstraction layers - child_chunk_ids = in_chunk_connected_ids.copy() - for layer_id in range(3, n_layers + 1): - - time_start = time.time() - - print("\n\n\n --- LAYER %d --- \n\n\n" % layer_id) - - parent_chunk_ids = child_chunk_ids // cg.fan_out - parent_chunk_ids = parent_chunk_ids.astype(np.int) - - u_pcids, inds = np.unique(parent_chunk_ids, - axis=0, return_inverse=True) - - if len(u_pcids) > n_threads: - n_threads_per_process = 1 - else: - n_threads_per_process = int(np.ceil(n_threads / len(u_pcids))) - - multi_args = [] - for ind in range(len(u_pcids)): - multi_args.append([table_id, layer_id, - child_chunk_ids[inds == ind].astype(np.int), - n_threads_per_process]) - - child_chunk_ids = u_pcids - - # Run parallelizing - if n_threads == 1: - mu.multiprocess_func( - _add_layer_thread, multi_args, n_threads=n_threads, - verbose=True, - debug=n_threads == 1) - else: - mu.multisubprocess_func( - _add_layer_thread, multi_args, n_threads=n_threads, - suffix=str(layer_id)) - - times.append(["Layer %d" % layer_id, time.time() - time_start]) - - for time_entry in times: - print("%s: %.2fs = %.2fmin = %.2fh" % (time_entry[0], time_entry[1], - time_entry[1] / 60, - time_entry[1] / 3600)) - - -def _preprocess_chunkedgraph_data_thread(args): - """ Reads downloaded files and sorts them in _in_ and _between_ chunks """ - - file_paths, table_id, chunk_size, bbox = args - - if bbox is None: - bbox = [[0, 0, 0], [np.inf, np.inf, np.inf]] - - bbox = np.array(bbox) - - in_chunk_connected_paths = np.array([]) - in_chunk_connected_ids = np.array([], dtype=np.uint64).reshape(-1, 3) - in_chunk_disconnected_paths = np.array([]) - in_chunk_disconnected_ids = np.array([], dtype=np.uint64).reshape(-1, 3) - between_chunk_paths = np.array([]) - between_chunk_ids = np.array([], dtype=np.uint64).reshape(-1, 2, 3) - isolated_paths = np.array([]) - isolated_ids = np.array([], dtype=np.uint64).reshape(-1, 3) - - # Read file paths - gather chunk ids and in / out properties - for i_fp, fp in enumerate(file_paths): - file_name = os.path.basename(fp).split(".")[0] - - # Read coordinates from file path - x1, x2, y1, y2, z1, z2 = np.array(re.findall("[\d]+", file_name), dtype=np.int)[:6] - - if np.any((bbox[0] - np.array([x2, y2, z2])) >= 0) or \ - np.any((bbox[1] - np.array([x1, y1, z1])) <= 0): - continue - - dx = x2 - x1 - dy = y2 - y1 - dz = z2 - z1 - - d = np.array([dx, dy, dz]) - c = np.array([x1, y1, z1]) - - # if there is a 2 in d then the file contains edges that cross chunks - gap = 2 - - if gap in d: - s_c = np.where(d == gap)[0] - chunk_coord = c.copy() - - chunk1_id = np.array(chunk_coord / chunk_size, dtype=np.int) - chunk_coord[s_c] += chunk_size[s_c] - chunk2_id = np.array(chunk_coord / chunk_size, dtype=np.int) - - between_chunk_ids = np.concatenate([between_chunk_ids, - np.array([chunk1_id, chunk2_id])[None]]) - between_chunk_paths = np.concatenate([between_chunk_paths, [fp]]) - else: - chunk_coord = np.array(c / chunk_size, dtype=np.int) - - if "disconnected" in file_name: - in_chunk_disconnected_ids = np.concatenate([in_chunk_disconnected_ids, chunk_coord[None]]) - in_chunk_disconnected_paths = np.concatenate([in_chunk_disconnected_paths, [fp]]) - elif "isolated" in file_name: - isolated_ids = np.concatenate([isolated_ids, chunk_coord[None]]) - isolated_paths = np.concatenate([isolated_paths, [fp]]) - else: - in_chunk_connected_ids = 
np.concatenate([in_chunk_connected_ids, chunk_coord[None]]) - in_chunk_connected_paths = np.concatenate([in_chunk_connected_paths, [fp]]) - - return in_chunk_connected_paths, in_chunk_connected_ids, \ - in_chunk_disconnected_paths, in_chunk_disconnected_ids, \ - between_chunk_paths, between_chunk_ids, \ - isolated_paths, isolated_ids - - -def _between_chunk_masks_thread(args): - """""" - between_chunk_ids, between_chunk_paths, in_chunk_id_block, offset = args - - between_chunk_paths_out_masked = [] - between_chunk_paths_in_masked = [] - - for i_in_chunk_id, in_chunk_id in enumerate(in_chunk_id_block): - out_paths_mask = np.sum(np.abs(between_chunk_ids[:, 0] - in_chunk_id), axis=1) == 0 - in_paths_masks = np.sum(np.abs(between_chunk_ids[:, 1] - in_chunk_id), axis=1) == 0 - - between_chunk_paths_out_masked.append(between_chunk_paths[out_paths_mask]) - between_chunk_paths_in_masked.append(between_chunk_paths[in_paths_masks]) - - return offset, between_chunk_paths_out_masked, between_chunk_paths_in_masked - - -def _create_atomic_layer_thread(args): - """ Fills lowest layer and create first abstraction layer """ - # Load args - table_id, chunk_connected_path, chunk_disconnected_path, isolated_path,\ - in_paths, out_paths, verbose = args - - # Load edge information - edge_ids = {"in_connected": np.array([], dtype=np.uint64).reshape(0, 2), - "in_disconnected": np.array([], dtype=np.uint64).reshape(0, 2), - "cross": np.array([], dtype=np.uint64).reshape(0, 2), - "between_connected": np.array([], dtype=np.uint64).reshape(0, 2), - "between_disconnected": np.array([], dtype=np.uint64).reshape(0, 2)} - edge_affs = {"in_connected": np.array([], dtype=np.float32), - "in_disconnected": np.array([], dtype=np.float32), - "between_connected": np.array([], dtype=np.float32), - "between_disconnected": np.array([], dtype=np.float32)} - edge_areas = {"in_connected": np.array([], dtype=np.float32), - "in_disconnected": np.array([], dtype=np.float32), - "between_connected": np.array([], dtype=np.float32), - "between_disconnected": np.array([], dtype=np.float32)} - - in_connected_dict = creator_utils.read_edge_file_h5(chunk_connected_path) - in_disconnected_dict = creator_utils.read_edge_file_h5(chunk_disconnected_path) - - edge_ids["in_connected"] = in_connected_dict["edge_ids"] - edge_affs["in_connected"] = in_connected_dict["edge_affs"] - edge_areas["in_connected"] = in_connected_dict["edge_areas"] - - edge_ids["in_disconnected"] = in_disconnected_dict["edge_ids"] - edge_affs["in_disconnected"] = in_disconnected_dict["edge_affs"] - edge_areas["in_disconnected"] = in_disconnected_dict["edge_areas"] - - if os.path.exists(isolated_path): - isolated_ids = creator_utils.read_edge_file_h5(isolated_path)["node_ids"] - else: - isolated_ids = np.array([], dtype=np.uint64) - - for fp in in_paths: - edge_dict = creator_utils.read_edge_file_h5(fp) - - # Cross edges are always ordered to point OUT of the chunk - if "unbreakable" in fp: - edge_ids["cross"] = np.concatenate([edge_ids["cross"], edge_dict["edge_ids"][:, [1, 0]]]) - elif "disconnected" in fp: - edge_ids["between_disconnected"] = np.concatenate([edge_ids["between_disconnected"], edge_dict["edge_ids"][:, [1, 0]]]) - edge_affs["between_disconnected"] = np.concatenate([edge_affs["between_disconnected"], edge_dict["edge_affs"]]) - edge_areas["between_disconnected"] = np.concatenate([edge_areas["between_disconnected"], edge_dict["edge_areas"]]) - else: - # connected - edge_ids["between_connected"] = np.concatenate([edge_ids["between_connected"], 
edge_dict["edge_ids"][:, [1, 0]]]) - edge_affs["between_connected"] = np.concatenate([edge_affs["between_connected"], edge_dict["edge_affs"]]) - edge_areas["between_connected"] = np.concatenate([edge_areas["between_connected"], edge_dict["edge_areas"]]) - - for fp in out_paths: - edge_dict = creator_utils.read_edge_file_h5(fp) - - if "unbreakable" in fp: - edge_ids["cross"] = np.concatenate([edge_ids["cross"], edge_dict["edge_ids"]]) - elif "disconnected" in fp: - edge_ids["between_disconnected"] = np.concatenate([edge_ids["between_disconnected"], edge_dict["edge_ids"]]) - edge_affs["between_disconnected"] = np.concatenate([edge_affs["between_disconnected"], edge_dict["edge_affs"]]) - edge_areas["between_disconnected"] = np.concatenate([edge_areas["between_disconnected"], edge_dict["edge_areas"]]) - else: - # connected - edge_ids["between_connected"] = np.concatenate([edge_ids["between_connected"], edge_dict["edge_ids"]]) - edge_affs["between_connected"] = np.concatenate([edge_affs["between_connected"], edge_dict["edge_affs"]]) - edge_areas["between_connected"] = np.concatenate([edge_areas["between_connected"], edge_dict["edge_areas"]]) - - # Initialize an ChunkedGraph instance and write to it - cg = chunkedgraph.ChunkedGraph(table_id=table_id) - - cg.add_atomic_edges_in_chunks(edge_ids, edge_affs, edge_areas, - isolated_node_ids=isolated_ids, - verbose=verbose) - - -def _add_layer_thread(args): - """ Creates abstraction layer """ - table_id, layer_id, chunk_coords, n_threads_per_process = args - - cg = chunkedgraph.ChunkedGraph(table_id=table_id) - cg.add_layer(layer_id, chunk_coords, n_threads=n_threads_per_process) - diff --git a/pychunkedgraph/creator/creator_utils.py b/pychunkedgraph/creator/creator_utils.py deleted file mode 100644 index 4bca81629..000000000 --- a/pychunkedgraph/creator/creator_utils.py +++ /dev/null @@ -1,116 +0,0 @@ -import numpy as np -import h5py -import os - -HOME = os.path.expanduser("~") - - -def layer_name_from_cv_url(cv_url): - return cv_url.strip("/").split("/")[-2] - - -def dir_from_layer_name(layer_name): - return HOME + "/" + layer_name + "/" - - -def read_edge_file_cv(cv_st, path): - """ Reads the edge ids and affinities from an edge file """ - - if 'unbreakable' in path: - dt = 'uint64, uint64' - elif 'isolated' in path: - dt = 'uint64' - else: - dt = 'uint64, uint64, float32, uint64' - - buf = cv_st.get_file(path) - edge_data = np.frombuffer(buf, dtype=dt) - - if len(edge_data) == 0: - if len(dt.split(",")) == 1: - edge_data = np.array([], dtype=np.uint64) - else: - edge_data = {"f0": np.array([], dtype=np.uint64), - "f1": np.array([], dtype=np.uint64), - "f2": np.array([], dtype=np.float32), - "f3": np.array([], dtype=np.uint64)} - - if 'isolated' in path: - edge_dict = {"node_ids": edge_data} - else: - edge_ids = np.concatenate([edge_data["f0"].reshape(-1, 1), - edge_data["f1"].reshape(-1, 1)], axis=1) - - edge_dict = {"edge_ids": edge_ids} - - if "connected" in path: - edge_dict['edge_affs'] = edge_data['f2'] - edge_dict['edge_areas'] = edge_data['f3'] - - return edge_dict - - -def read_edge_file_h5(path, layer_name=None): - if not path.endswith(".h5"): - path = path[:-4] + ".h5" - - if layer_name is not None: - path = dir_from_layer_name(layer_name) + path - - edge_dict = {} - with h5py.File(path, "r") as f: - for k in f.keys(): - edge_dict[k] = f[k].value - - return edge_dict - - -def download_and_store_edge_file(cv_st, path, create_dir=True): - edge_dict = read_edge_file_cv(cv_st, path) - - dir_path = 
dir_from_layer_name(layer_name_from_cv_url(cv_st.layer_path)) - - if not os.path.exists(dir_path) and create_dir: - os.makedirs(dir_path) - - with h5py.File(dir_path + path[:-4] + ".h5", "w") as f: - for k in edge_dict.keys(): - f.create_dataset(k, data=edge_dict[k], compression="gzip") - - -# def read_mapping_cv(cv_st, path, olduint32=False): -# """ Reads the mapping information from a file """ -# -# if olduint32: -# mapping = np.frombuffer(cv_st.get_file(path), -# dtype=np.uint64).reshape(-1, 2) -# mapping_to = mapping[:, 1] -# mapping_from = np.frombuffer(np.ascontiguousarray(mapping[:, 0]), dtype=np.uint32)[::2].astype(np.uint64) -# return np.concatenate([mapping_from[:, None], mapping_to[:, None]], axis=1) -# else: -# return np.frombuffer(cv_st.get_file(path), dtype=np.uint64).reshape(-1, 2) -# -# -# def read_mapping_h5(path, layer_name=None): -# if not path.endswith(".h5"): -# path = path[:-4] + ".h5" -# -# if layer_name is not None: -# path = dir_from_layer_name(layer_name) + path -# -# with h5py.File(path, "r") as f: -# mapping = f["mapping"].value -# -# return mapping -# -# -# def download_and_store_mapping_file(cv_st, path, olduint32=False): -# mapping = read_mapping_cv(cv_st, path, olduint32=olduint32) -# -# dir_path = dir_from_layer_name(layer_name_from_cv_url(cv_st.layer_path)) -# -# if not os.path.exists(dir_path): -# os.makedirs(dir_path) -# -# with h5py.File(dir_path + path[:-4] + ".h5", "w") as f: -# f["mapping"] = mapping \ No newline at end of file diff --git a/pychunkedgraph/creator/data_test.py b/pychunkedgraph/creator/data_test.py deleted file mode 100644 index 48af7b51b..000000000 --- a/pychunkedgraph/creator/data_test.py +++ /dev/null @@ -1,51 +0,0 @@ -import glob -import collections -import numpy as np - -from pychunkedgraph.creator import creator_utils -from multiwrapper import multiprocessing_utils as mu - - -def _test_unique_edge_assignment_thread(args): - paths = args[0] - - id_dict = collections.Counter() - for path in paths: - try: - ids = creator_utils.read_edge_file_h5(path)["edge_ids"] - except: - ids = creator_utils.read_edge_file_h5(path)["node_ids"] - u_ids = np.unique(ids) - - u_id_d = dict(zip(u_ids, np.ones(len(u_ids), dtype=np.int))) - add_counter = collections.Counter(u_id_d) - - id_dict += add_counter - - # return np.array(list(id_dict.items())) - return id_dict - - -def test_unique_edge_assignment(dir, n_threads=128): - file_paths = glob.glob(dir + "/*") - - file_chunks = np.array_split(file_paths, n_threads * 3) - multi_args = [] - for i_file_chunk, file_chunk in enumerate(file_chunks): - multi_args.append([file_chunk]) - - # Run parallelizing - if n_threads == 1: - results = mu.multiprocess_func(_test_unique_edge_assignment_thread, - multi_args, n_threads=n_threads, - verbose=True, debug=n_threads==1) - else: - results = mu.multiprocess_func(_test_unique_edge_assignment_thread, - multi_args, n_threads=n_threads) - - id_dict = collections.Counter() - for result in results: - # id_dict += collections.Counter(dict(result)) - id_dict += result - - return id_dict \ No newline at end of file diff --git a/pychunkedgraph/creator/graph_tests.py b/pychunkedgraph/creator/graph_tests.py deleted file mode 100644 index 54d51411a..000000000 --- a/pychunkedgraph/creator/graph_tests.py +++ /dev/null @@ -1,133 +0,0 @@ -import itertools -import numpy as np -import time - -import pychunkedgraph.backend.chunkedgraph_utils -from pychunkedgraph.backend.utils import column_keys -from pychunkedgraph.backend import chunkedgraph -from multiwrapper import 
multiprocessing_utils as mu - - -def _family_consistency_test_thread(args): - """ Helper to test family consistency """ - - table_id, coord, layer_id = args - - x, y, z = coord - - cg = chunkedgraph.ChunkedGraph(table_id) - - rows = cg.range_read_chunk(layer_id, x, y, z) - parent_column = column_keys.Hierarchy.Parent - - failed_node_ids = [] - - time_start = time.time() - for i_k, node_id in enumerate(rows.keys()): - if i_k % 100 == 1: - dt = time.time() - time_start - eta = dt / i_k * len(rows) - dt - print("%d / %d - %.3fs -> %.3fs " % (i_k, len(rows), dt, eta), - end="\r") - - parent_id = rows[node_id][parent_column][0].value - - if node_id not in cg.get_children(parent_id): - failed_node_ids.append([node_id, parent_id]) - - return failed_node_ids - - -def family_consistency_test(table_id, n_threads=64): - """ Runs a simple test on the WHOLE graph - - tests: id in children(parent(id)) - - :param table_id: str - :param n_threads: int - :return: dict - n x 2 per layer - each failed pair: (node_id, parent_id) - """ - - cg = chunkedgraph.ChunkedGraph(table_id) - - failed_node_id_dict = {} - for layer_id in range(1, cg.n_layers): - print("\n\n Layer %d \n\n" % layer_id) - - step = int(cg.fan_out ** np.max([0, layer_id - 2])) - coords = list(itertools.product(range(0, 8, step), - range(0, 8, step), - range(0, 4, step))) - - multi_args = [] - for coord in coords: - multi_args.append([table_id, coord, layer_id]) - - collected_failed_node_ids = mu.multisubprocess_func( - _family_consistency_test_thread, multi_args, n_threads=n_threads) - - failed_node_ids = [] - for _failed_node_ids in collected_failed_node_ids: - failed_node_ids.extend(_failed_node_ids) - - failed_node_id_dict[layer_id] = np.array(failed_node_ids) - - print("\n%d nodes rows failed\n" % len(failed_node_ids)) - - return failed_node_id_dict - - -def children_test(table_id, layer, coord_list): - - cg = chunkedgraph.ChunkedGraph(table_id) - child_column = column_keys.Hierarchy.Child - - for coords in coord_list: - x, y, z = coords - - node_ids = cg.range_read_chunk(layer, x, y, z, columns=child_column) - all_children = [] - children_chunks = [] - for children in node_ids.values(): - children = children[0].value - for child in children: - all_children.append(child) - children_chunks.append(cg.get_chunk_id(child)) - - u_children_chunks, c_children_chunks = np.unique(children_chunks, - return_counts=True) - u_chunk_coords = [cg.get_chunk_coordinates(c) for c in u_children_chunks] - - print("\n--- Layer %d ---- [%d, %d, %d] ---" % (layer, x, y, z)) - print("N(all children): %d" % len(all_children)) - print("N(unique children): %d" % len(np.unique(all_children))) - print("N(unique children chunks): %d" % len(u_children_chunks)) - print("Unique children chunk coords", u_chunk_coords) - print("N(ids per unique children chunk):", c_children_chunks) - - -def root_cross_edge_test(node_id, table_id=None, cg=None): - if cg is None: - assert isinstance(table_id, str) - cg = chunkedgraph.ChunkedGraph(table_id) - - cross_edge_dict_layers = {} - cross_edge_dict_children = {} - for layer in range(2, cg.n_layers): - child_ids = cg.get_subgraph_nodes(node_id, return_layers=[layer]) - - cross_edge_dict = {} - child_reference_ids = [] - for child_id in child_ids: - cross_edge_dict = pychunkedgraph.backend.chunkedgraph_utils.combine_cross_chunk_edge_dicts(cross_edge_dict, cg.read_cross_chunk_edges(child_id)) - - cross_edge_dict_layers[layer] = cross_edge_dict - - for layer in cross_edge_dict_layers.keys(): - print("\n--------\n") - for i_layer in 
cross_edge_dict_layers[layer].keys(): - print(layer, i_layer, len(cross_edge_dict_layers[layer][i_layer])) - - return cross_edge_dict_layers diff --git a/pychunkedgraph/debug/__init__.py b/pychunkedgraph/debug/__init__.py new file mode 100644 index 000000000..4be0ff102 --- /dev/null +++ b/pychunkedgraph/debug/__init__.py @@ -0,0 +1,3 @@ +""" +Modules for debugging. +""" \ No newline at end of file diff --git a/pychunkedgraph/debug/cross_edge_test.py b/pychunkedgraph/debug/cross_edge_test.py new file mode 100644 index 000000000..25bacfa0b --- /dev/null +++ b/pychunkedgraph/debug/cross_edge_test.py @@ -0,0 +1,60 @@ +import os +from datetime import datetime +import numpy as np + +from pychunkedgraph.graph import chunkedgraph +from pychunkedgraph.graph import attributes + +#os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/home/svenmd/.cloudvolume/secrets/google-secret.json" + +layer = 2 +n_chunks = 1000 +n_segments_per_chunk = 200 +# timestamp = datetime.datetime.fromtimestamp(1588875769) +timestamp = datetime.utcnow() + +cg = chunkedgraph.ChunkedGraph(graph_id="pinky_nf_v2") + +np.random.seed(42) + +node_ids = [] +for _ in range(n_chunks): + c_x = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][0]) + c_y = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][1]) + c_z = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][2]) + + chunk_id = cg.get_chunk_id(layer=layer, x=c_x, y=c_y, z=c_z) + + max_segment_id = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id)) + + if max_segment_id < 10: + continue + + segment_ids = np.random.randint(1, max_segment_id, n_segments_per_chunk) + + for segment_id in segment_ids: + node_ids.append(cg.get_node_id(np.uint64(segment_id), np.uint64(chunk_id))) + +rows = cg.client.read_nodes(node_ids=node_ids, end_time=timestamp, + properties=attributes.Hierarchy.Parent) +valid_node_ids = [] +non_valid_node_ids = [] +for k in rows.keys(): + if len(rows[k]) > 0: + valid_node_ids.append(k) + else: + non_valid_node_ids.append(k) + +cc_edges = cg.get_atomic_cross_edges(valid_node_ids) +cc_ids = np.unique(np.concatenate([np.concatenate(list(v.values())) for v in list(cc_edges.values()) if len(v.values())])) + +roots = cg.get_roots(cc_ids) +root_dict = dict(zip(cc_ids, roots)) +root_dict_vec = np.vectorize(root_dict.get) + +for k in cc_edges: + if len(cc_edges[k]) == 0: + continue + local_ids = np.unique(np.concatenate(list(cc_edges[k].values()))) + + assert len(np.unique(root_dict_vec(local_ids))) \ No newline at end of file diff --git a/pychunkedgraph/debug/edges.py b/pychunkedgraph/debug/edges.py new file mode 100644 index 000000000..bdd59dadf --- /dev/null +++ b/pychunkedgraph/debug/edges.py @@ -0,0 +1,11 @@ +from time import time +from ..graph import ChunkedGraph + +cg = ChunkedGraph(graph_id="minnie3_v0") + + +def get_subgraph(node_id): + start = time() + result = cg.get_subgraph(node_id, nodes_only=True) + print("cg.get_subgraph", time() - start) + return result diff --git a/pychunkedgraph/debug/edits.py b/pychunkedgraph/debug/edits.py new file mode 100644 index 000000000..4773bc912 --- /dev/null +++ b/pychunkedgraph/debug/edits.py @@ -0,0 +1,103 @@ +from typing import Union +from typing import Tuple + +import numpy as np + +from ..graph import ChunkedGraph +from ..graph.operation import GraphEditOperation +from ..graph.operation import MergeOperation +from ..graph.operation import MulticutOperation +from ..graph.operation import SplitOperation +from ..app.app_utils import handle_supervoxel_id_lookup + +USER_ID = "debug" + + +def 
_parse_merge_payload( + cg: ChunkedGraph, user_id: str, payload: list +) -> MergeOperation: + + node_ids = [] + coords = [] + for node in payload: + node_ids.append(node[0]) + coords.append(np.array(node[1:]) / cg.segmentation_resolution) + + atomic_edge = handle_supervoxel_id_lookup(cg, coords, node_ids) + chunk_coord_delta = cg.get_chunk_coordinates( + atomic_edge[0] + ) - cg.get_chunk_coordinates(atomic_edge[1]) + if np.any(np.abs(chunk_coord_delta) > 3): + raise ValueError("Chebyshev distance exceeded allowed maximum.") + + return ( + node_ids, + atomic_edge, + MergeOperation( + cg, + user_id=user_id, + added_edges=np.array(atomic_edge, dtype=np.uint64), + source_coords=coords[:1], + sink_coords=coords[1:], + ), + ) + + +def _parse_split_payload( + cg: ChunkedGraph, user_id: str, payload: dict, mincut: bool = True +) -> Union[SplitOperation, MulticutOperation]: + node_idents = [] + node_ident_map = { + "sources": 0, + "sinks": 1, + } + coords = [] + node_ids = [] + + for k in ["sources", "sinks"]: + for node in payload[k]: + node_ids.append(node[0]) + coords.append(np.array(node[1:]) / cg.segmentation_resolution) + node_idents.append(node_ident_map[k]) + + node_ids = np.array(node_ids, dtype=np.uint64) + coords = np.array(coords) + node_idents = np.array(node_idents) + sv_ids = handle_supervoxel_id_lookup(cg, coords, node_ids) + + source_ids = sv_ids[node_idents == 0] + sink_ids = sv_ids[node_idents == 1] + source_coords = coords[node_idents == 0] + sink_coords = coords[node_idents == 1] + + bb_offset = (240, 240, 24) + return ( + source_ids, + sink_ids, + MulticutOperation( + cg, + user_id=user_id, + source_ids=source_ids, + sink_ids=sink_ids, + source_coords=source_coords, + sink_coords=sink_coords, + bbox_offset=bb_offset, + path_augment=True, + disallow_isolating_cut=True, + ), + ) + + +def get_operation_from_request_payload( + cg: ChunkedGraph, + payload: Union[list, dict], + split: bool, + *, + mincut: bool = True, + user_id: str = None, +) -> Tuple[np.ndarray, np.ndarray, GraphEditOperation]: + if user_id is None: + user_id = USER_ID + if split: + return _parse_split_payload(cg, user_id, payload, mincut=mincut) + return _parse_merge_payload(cg, user_id, payload) diff --git a/pychunkedgraph/debug/existence_test.py b/pychunkedgraph/debug/existence_test.py new file mode 100644 index 000000000..757d3d542 --- /dev/null +++ b/pychunkedgraph/debug/existence_test.py @@ -0,0 +1,78 @@ +import os +from datetime import datetime +import numpy as np + +from pychunkedgraph.graph import chunkedgraph +from pychunkedgraph.graph import attributes + +#os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/home/svenmd/.cloudvolume/secrets/google-secret.json" + +layer = 2 +n_chunks = 100 +n_segments_per_chunk = 200 +# timestamp = datetime.datetime.fromtimestamp(1588875769) +timestamp = datetime.utcnow() + +cg = chunkedgraph.ChunkedGraph(graph_id="pinky_nf_v2") + +np.random.seed(42) + +node_ids = [] +for _ in range(n_chunks): + c_x = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][0]) + c_y = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][1]) + c_z = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][2]) + + chunk_id = cg.get_chunk_id(layer=layer, x=c_x, y=c_y, z=c_z) + + max_segment_id = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id)) + + if max_segment_id < 10: + continue + + segment_ids = np.random.randint(1, max_segment_id, n_segments_per_chunk) + + for segment_id in segment_ids: + node_ids.append(cg.get_node_id(np.uint64(segment_id), np.uint64(chunk_id))) + +rows = 
cg.client.read_nodes(node_ids=node_ids, end_time=timestamp, + properties=attributes.Hierarchy.Parent) +valid_node_ids = [] +non_valid_node_ids = [] +for k in rows.keys(): + if len(rows[k]) > 0: + valid_node_ids.append(k) + else: + non_valid_node_ids.append(k) + +roots = cg.get_roots(valid_node_ids, time_stamp=timestamp) + +roots = [] +try: + roots = cg.get_roots(valid_node_ids) + assert len(roots) == len(valid_node_ids) + print(f"ALL {len(roots)} have been successful!") +except: + print("At least one node failed. Checking nodes one by one now") + +if len(roots) != len(valid_node_ids): + log_dict = {} + success_dict = {} + for node_id in valid_node_ids: + try: + root = cg.get_root(node_id, time_stamp=timestamp) + print(f"Success: {node_id} from chunk {cg.get_chunk_id(node_id)}") + success_dict[node_id] = True + except Exception as e: + print(f"{node_id} from chunk {cg.get_chunk_id(node_id)} failed with {e}") + success_dict[node_id] = False + + t_id = node_id + + while t_id is not None: + last_working_chunk = cg.get_chunk_id(t_id) + t_id = cg.get_parent(t_id) + + print(f"Failed on layer {cg.get_chunk_layer(last_working_chunk)} in chunk {last_working_chunk}") + log_dict[node_id] = last_working_chunk + diff --git a/pychunkedgraph/debug/family_test.py b/pychunkedgraph/debug/family_test.py new file mode 100644 index 000000000..198351e74 --- /dev/null +++ b/pychunkedgraph/debug/family_test.py @@ -0,0 +1,54 @@ +import os +from datetime import datetime +import numpy as np + +from pychunkedgraph.graph import chunkedgraph +from pychunkedgraph.graph import attributes + +# os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/home/svenmd/.cloudvolume/secrets/google-secret.json" + +layers = [2, 3, 4, 5, 6, 7] +n_chunks = 10 +n_segments_per_chunk = 200 +# timestamp = datetime.datetime.fromtimestamp(1588875769) +timestamp = datetime.utcnow() + +cg = chunkedgraph.ChunkedGraph(graph_id="pinky_nf_v2") + +np.random.seed(42) + +node_ids = [] + +for layer in layers: + for _ in range(n_chunks): + c_x = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][0]) + c_y = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][1]) + c_z = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][2]) + + chunk_id = cg.get_chunk_id(layer=layer, x=c_x, y=c_y, z=c_z) + + max_segment_id = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id)) + + if max_segment_id < 10: + continue + + segment_ids = np.random.randint(1, max_segment_id, n_segments_per_chunk) + + for segment_id in segment_ids: + node_ids.append(cg.get_node_id(np.uint64(segment_id), np.uint64(chunk_id))) + +rows = cg.client.read_nodes(node_ids=node_ids, end_time=timestamp, + properties=attributes.Hierarchy.Parent) +valid_node_ids = [] +non_valid_node_ids = [] +for k in rows.keys(): + if len(rows[k]) > 0: + valid_node_ids.append(k) + else: + non_valid_node_ids.append(k) + +parents = cg.get_parents(valid_node_ids, time_stamp=timestamp) +children_dict = cg.get_children(parents) + +for child, parent in zip(valid_node_ids, parents): + assert child in children_dict[parent] \ No newline at end of file diff --git a/pychunkedgraph/debug/utils.py b/pychunkedgraph/debug/utils.py new file mode 100644 index 000000000..179f50aef --- /dev/null +++ b/pychunkedgraph/debug/utils.py @@ -0,0 +1,43 @@ +import numpy as np + +from ..graph import ChunkedGraph +from ..graph.utils.basetypes import NODE_ID + + +def print_attrs(d): + for k, v in d.items(): + try: + print(k.key) + except: + print(k) + try: + print(v[:2], "...") if type(v) is np.ndarray and len(v) > 2 else print(v) + except: + 
print(v) + + +def print_node( + cg: ChunkedGraph, + node: NODE_ID, + indent: int = 0, + stop_layer: int = 2, +) -> None: + children = cg.get_children(node) + print(f"{' ' * indent}{node}[{len(children)}]") + if cg.get_chunk_layer(node) <= stop_layer: + return + for child in children: + print_node(cg, child, indent=indent + 1, stop_layer=stop_layer) + + +def get_l2children(cg: ChunkedGraph, node: NODE_ID) -> np.ndarray: + nodes = np.array([node], dtype=NODE_ID) + layers = cg.get_chunk_layers(nodes) + assert np.all(layers > 2), "nodes must be at layers > 2" + l2children = [] + while nodes.size: + children = cg.get_children(nodes, flatten=True) + layers = cg.get_chunk_layers(children) + l2children.append(children[layers == 2]) + nodes = children[layers > 2] + return np.concatenate(l2children) diff --git a/pychunkedgraph/edge_gen/Dockerfile b/pychunkedgraph/edge_gen/Dockerfile deleted file mode 100644 index 33e07707f..000000000 --- a/pychunkedgraph/edge_gen/Dockerfile +++ /dev/null @@ -1,38 +0,0 @@ -FROM python:3-alpine - -RUN apk add --no-cache --virtual .build-deps \ - curl \ - libc6-compat \ - git \ - gcc \ - g++ \ - linux-headers \ - jpeg-dev \ - mariadb-dev \ - && apk add --no-cache \ - libstdc++ \ - libjpeg-turbo \ - mariadb-connector-c \ - \ - # separate numpy install fixes cloudvolume bug - && pip install --no-cache-dir \ - numpy \ - && pip install --no-cache-dir --upgrade \ - cloud-volume \ - tenacity \ - networkx \ - google-cloud-bigtable \ - zstandard \ - mysqlclient \ - && mkdir /root/.cloudvolume \ - && ln -s /secrets /root/.cloudvolume/secrets \ - \ - && git clone "https://github.com/seung-lab/pychunkedgraph.git" /usr/local/pychunkedgraph \ - && rm -rf /usr/local/pychunkedgraph/.git \ - && apk del .build-deps \ - && find /usr/local -depth \ - \( \ - \( -type d -a \( -name __pycache__ \) \) \ - -o \ - \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ - \) -exec rm -rf '{}' + diff --git a/pychunkedgraph/edge_gen/edgetask.py b/pychunkedgraph/edge_gen/edgetask.py deleted file mode 100644 index 478b03434..000000000 --- a/pychunkedgraph/edge_gen/edgetask.py +++ /dev/null @@ -1,597 +0,0 @@ -import json -import os -import re -import sys -from copy import deepcopy -from functools import lru_cache -from itertools import chain -from operator import itemgetter -from typing import Iterable, Mapping, Tuple, Union - -from cloudvolume import CloudVolume, Storage - -import MySQLdb -import numpy as np -import zstandard as zstd - -sys.path.insert(0, os.path.join(sys.path[0], '..')) -from backend import chunkedgraph # noqa - -UINT64_ZERO = np.uint64(0) -UINT64_ONE = np.uint64(1) - - -class EdgeTask: - def __init__(self, - cgraph: chunkedgraph.ChunkedGraph, - mysql_conn: any, - agglomeration_input: CloudVolume, - watershed_input: CloudVolume, - regiongraph_input: Storage, - regiongraph_output: Storage, - regiongraph_chunksize: Tuple[int, int, int], - roi: Tuple[slice, slice, slice]): - self.__cgraph = cgraph - self.__mysql_conn = mysql_conn - self.__watershed = { - "cv_input": watershed_input, - "original": np.array([], dtype=np.uint64, ndmin=3), - "relabeled": np.array([], dtype=np.uint64, ndmin=3), - "rg2cg_complete": {}, - "rg2cg_boundary": {} - } - self.__agglomeration = { - "cv": agglomeration_input, - "original": np.array([], dtype=np.uint64, ndmin=3) - } - self.__regiongraph = { - "storage_in": regiongraph_input, - "storage_out": regiongraph_output, - "edges": {}, - "chunksize": regiongraph_chunksize, - "offset": self.__watershed["cv_input"].voxel_offset, - "maxlevel": 
int(np.ceil(np.log2(np.max(np.floor_divide( - self.__watershed["cv_input"].volume_size, regiongraph_chunksize))))) - } - self.__roi = roi - self.__watershed["original"] = self.__watershed["cv_input"][self.__roi] - self.__watershed["relabeled"] = np.empty_like(self.__watershed["original"]) - self.__agglomeration["original"] = \ - self.__agglomeration["cv"][self.__roi] - - def execute(self): - self.__relabel_cutout() - - self.__compute_cutout_regiongraph() - - return - - def get_relabeled_watershed(self): - return self.__watershed["relabeled"][0:-1, 0:-1, 0:-1, :] - - def __load_rg_chunkhierarchy_affinities(self): - """ - Collect all weighted edges from the Region Graph chunk hierarchy - within the ROI. - """ - - # Convert ROI (in voxel) to Region Graph chunk indices - chunk_range = tuple(map( - lambda x: - np.floor_divide( - np.maximum(0, np.subtract(x, self.__regiongraph["offset"])), - self.__regiongraph["chunksize"]), - ((self.__roi[0].start, self.__roi[1].start, self.__roi[2].start), - (self.__roi[0].stop, self.__roi[1].stop, self.__roi[2].stop)) - )) - - # TODO: Possible speedup by skipping high level chunks that don't - # intersect with ROI - edges = [] - for l in range(self.__regiongraph["maxlevel"] + 1): - for x in range(chunk_range[0][0], chunk_range[1][0] + 1): - for y in range(chunk_range[0][1], chunk_range[1][1] + 1): - for z in range(chunk_range[0][2], chunk_range[1][2] + 1): - print("Loading layer %i: (%i,%i,%i)" % (l, x, y, z)) - chunk_path = "edges_%i_%i_%i_%i.data.zst" % (l, x, y, z) - edges.append(load_rg_chunk_affinities( - self.__regiongraph["storage_in"], chunk_path) - ) - - chunk_range = (chunk_range[0] // 2, chunk_range[1] // 2) - - print("Converting to Set") - return {e.item()[0:2]: e for e in chain(*edges)} - - def __load_cutout_labels_from_db(self): - chunks_to_fetch = [] - for x in range(self.__roi[0].start, self.__roi[0].stop, self.__cgraph.chunk_size[0]): - for y in range(self.__roi[1].start, self.__roi[1].stop, self.__cgraph.chunk_size[1]): - for z in range(self.__roi[2].start, self.__roi[2].stop, self.__cgraph.chunk_size[2]): - chunks_to_fetch.append(self.__cgraph.get_chunk_id_from_coord(1, x, y, z)) - - self.__mysql_conn.query("SELECT id, edges FROM chunkedges WHERE id IN (%s);" % ",".join(str(x) for x in chunks_to_fetch)) - res = self.__mysql_conn.store_result() - - chunk_labels = {x: {} for x in chunks_to_fetch} - for row in res.fetch_row(maxrows=0): - edges_iter = iter(np.frombuffer(row[1], dtype=np.uint64)) - chunk_labels[row[0]] = dict(zip(edges_iter, edges_iter)) - - self.__watershed["rg2cg_boundary"] = chunk_labels - self.__watershed["rg2cg_complete"] = deepcopy(self.__watershed["rg2cg_boundary"]) - - def __save_cutout_labels_to_db(self): - self.__mysql_conn.query("START TRANSACTION;") - - chunk_labels = self.__watershed["rg2cg_boundary"] - for chunk_id, mappings in chunk_labels.items(): - if len(mappings) > 0: - flat_binary_mapping = np.fromiter((item for k in mappings for item in (k, mappings[k])), dtype=np.uint64).tobytes() - flat_binary_mapping_escaped = self.__mysql_conn.escape_string(flat_binary_mapping) - self.__mysql_conn.query(b"INSERT INTO chunkedges (id, edges) VALUES (%i, \"%s\") ON DUPLICATE KEY UPDATE edges = VALUES(edges);" % (chunk_id, flat_binary_mapping_escaped)) - - self.__mysql_conn.query("COMMIT;") - - def __relabel_cutout(self): - # Load existing labels of center + neighboring chunks - self.__load_cutout_labels_from_db() - assigned_node_ids = {node_id for chunk_edges in self.__watershed["rg2cg_complete"].values() for node_id in 
chunk_edges.values()} - - def relabel_chunk(chunk_id: np.uint64, view_range: Tuple[slice, slice, slice]): - next_segment_id = UINT64_ONE - - original = np.nditer( - self.__watershed["original"][view_range], flags=['multi_index']) - relabeled = np.nditer(self.__watershed["relabeled"][view_range], flags=['multi_index'], op_flags=['writeonly']) - - print("Starting Loop for chunk %i" % chunk_id) - while not original.finished: - original_val = np.uint64(original[0]) - - if original_val == UINT64_ZERO: - # Don't relabel cell boundary (ID 0) - relabeled[0] = UINT64_ZERO - elif original_val in self.__watershed["rg2cg_complete"][chunk_id]: - # Already encountered this ID before. - relabeled[0] = relabeled_val = self.__watershed["rg2cg_complete"][chunk_id][original_val] - if original.multi_index[0] == 0 or \ - original.multi_index[1] == 0 or \ - original.multi_index[2] == 0: - self.__watershed["rg2cg_boundary"][chunk_id][original_val] = relabeled_val - else: - # Find new, unused node ID for this chunk. - while self.__cgraph.get_node_id( - segment_id=next_segment_id, - chunk_id=chunk_id) in assigned_node_ids: - next_segment_id += UINT64_ONE - - relabeled_val = self.__cgraph.get_node_id( - segment_id=next_segment_id, - chunk_id=chunk_id) - - relabeled[0] = relabeled_val - next_segment_id += UINT64_ONE - assigned_node_ids.add(relabeled_val) - - self.__watershed["rg2cg_complete"][chunk_id][original_val] = relabeled_val - if original.multi_index[0] == 0 or \ - original.multi_index[1] == 0 or \ - original.multi_index[2] == 0: - self.__watershed["rg2cg_boundary"][chunk_id][original_val] = relabeled_val - - original.iternext() - relabeled.iternext() - - for x_start in (0, int(self.__cgraph.chunk_size[0])): - for y_start in (0, int(self.__cgraph.chunk_size[1])): - for z_start in (0, int(self.__cgraph.chunk_size[2])): - x_end = x_start + int(self.__cgraph.chunk_size[0]) - y_end = y_start + int(self.__cgraph.chunk_size[1]) - z_end = z_start + int(self.__cgraph.chunk_size[2]) - - chunk_id = self.__cgraph.get_chunk_id_from_coord( - layer=1, - x=self.__roi[0].start + x_start, - y=self.__roi[1].start + y_start, - z=self.__roi[2].start + z_start) - - relabel_chunk(chunk_id, (slice(x_start, x_end), slice(y_start, y_end), slice(z_start, z_end))) - - self.__save_cutout_labels_to_db() - - def __compute_cutout_regiongraph(self): - edges_center_connected = np.array([]) - edges_center_disconnected = np.array([]) - isolated_sv = np.array([]) - edges_xplus_connected = np.array([]) - edges_xplus_disconnected = np.array([]) - edges_xplus_unbreakable = np.array([]) - edges_yplus_connected = np.array([]) - edges_yplus_disconnected = np.array([]) - edges_yplus_unbreakable = np.array([]) - edges_zplus_connected = np.array([]) - edges_zplus_disconnected = np.array([]) - edges_zplus_unbreakable = np.array([]) - - if np.any(self.__watershed["original"]): - # Download all region graph edges covering this part of the dataset - regiongraph_edges = self.__load_rg_chunkhierarchy_affinities() - - print("Calculating RegionGraph...") - - original = self.__watershed["original"] - agglomeration = self.__agglomeration["original"] - - # Shortcut to Original -> Relabeled supervoxel lookup table for the - # center chunk - rg2cg_center = self.__watershed["rg2cg_complete"][ - self.__cgraph.get_chunk_id_from_coord( - layer=1, - x=self.__roi[0].start, - y=self.__roi[1].start, - z=self.__roi[2].start)] - - # Original -> Relabeled supervoxel lookup table for chunk in X+ dir - rg2cg_xplus = self.__watershed["rg2cg_complete"][ - 
self.__cgraph.get_chunk_id_from_coord( - layer=1, - x=self.__roi[0].start + int(self.__cgraph.chunk_size[0]), - y=self.__roi[1].start, - z=self.__roi[2].start)] - - # Original -> Relabeled supervoxel lookup table for chunk in Y+ dir - rg2cg_yplus = self.__watershed["rg2cg_complete"][ - self.__cgraph.get_chunk_id_from_coord( - layer=1, - x=self.__roi[0].start, - y=self.__roi[1].start + int(self.__cgraph.chunk_size[1]), - z=self.__roi[2].start)] - - # Original -> Relabeled supervoxel lookup table for chunk in Z+ dir - rg2cg_zplus = self.__watershed["rg2cg_complete"][ - self.__cgraph.get_chunk_id_from_coord( - layer=1, - x=self.__roi[0].start, - y=self.__roi[1].start, - z=self.__roi[2].start + int(self.__cgraph.chunk_size[2]))] - - # Mask unsegmented voxel (ID=0) and voxel not at a supervoxel - # boundary in X-direction - sv_boundaries_x = \ - (original[:-1, :, :] != UINT64_ZERO) & (original[1:, :, :] != UINT64_ZERO) & \ - (original[:-1, :, :] != original[1:, :, :]) - - # Mask voxel that are not at an agglomeration boundary in X-direction - agg_boundaries_x = (agglomeration[:-1, :, :] == agglomeration[1:, :, :]) - - # Mask unsegmented voxel (ID=0) and voxel not at a supervoxel - # boundary in Y-direction - sv_boundaries_y = \ - (original[:, :-1, :] != UINT64_ZERO) & (original[:, 1:, :] != UINT64_ZERO) & \ - (original[:, :-1, :] != original[:, 1:, :]) - - # Mask voxel that are not at an agglomeration boundary in Y-direction - agg_boundaries_y = (agglomeration[:, :-1, :] == agglomeration[:, 1:, :]) - - # Mask unsegmented voxel (ID=0) and voxel not at a supervoxel - # boundary in Z-direction - sv_boundaries_z = \ - (original[:, :, :-1] != UINT64_ZERO) & (original[:, :, 1:] != UINT64_ZERO) & \ - (original[:, :, :-1] != original[:, :, 1:]) - - # Mask voxel that are not at an agglomeration boundary in Z-direction - agg_boundaries_z = (agglomeration[:, :, :-1] == agglomeration[:, :, 1:]) - - # Center Chunk: - # Collect all unique pairs of adjacent supervoxel IDs from the original - # watershed labeling that are part of the same agglomeration. - # Note that edges are sorted (lower supervoxel ID comes first). - edges_center_connected = {x if x[0] < x[1] else (x[1], x[0]) for x in chain( - zip(original[:-2, :-1, :-1][sv_boundaries_x[:-1, :-1, :-1] & agg_boundaries_x[:-1, :-1, :-1]], - original[1:-1, :-1, :-1][sv_boundaries_x[:-1, :-1, :-1] & agg_boundaries_x[:-1, :-1, :-1]]), - zip(original[:-1, :-2, :-1][sv_boundaries_y[:-1, :-1, :-1] & agg_boundaries_y[:-1, :-1, :-1]], - original[:-1, 1:-1, :-1][sv_boundaries_y[:-1, :-1, :-1] & agg_boundaries_y[:-1, :-1, :-1]]), - zip(original[:-1, :-1, :-2][sv_boundaries_z[:-1, :-1, :-1] & agg_boundaries_z[:-1, :-1, :-1]], - original[:-1, :-1, 1:-1][sv_boundaries_z[:-1, :-1, :-1] & agg_boundaries_z[:-1, :-1, :-1]]))} - - # Look up the affinity information for each edge and replace - # original supervoxel IDs with relabeled IDs - if edges_center_connected: - edges_center_connected = np.array([ - (*sorted(itemgetter(x[0], x[1])(rg2cg_center)), x[2], x[3]) - for x in [regiongraph_edges[e] for e in edges_center_connected] - ], dtype='uint64, uint64, float32, uint64') - else: - edges_center_connected = np.array([], dtype='uint64, uint64, float32, uint64') - - # Collect all unique pairs of adjacent supervoxel IDs from the original - # watershed labeling that are NOT part of the same agglomeration. 
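# A minimal, self-contained 1-D sketch (with made-up toy labels, not taken
# from any dataset) of the masking idiom used in this method: supervoxel- and
# agglomeration-boundary masks combine into sorted "connected" /
# "disconnected" ID pairs. The statements above and below apply the same idea
# per axis to 3-D cutouts.
import numpy as np

ws = np.array([0, 5, 5, 7, 7, 7, 9])    # toy watershed supervoxel IDs
agg = np.array([0, 1, 1, 1, 1, 1, 2])   # toy agglomeration IDs

# boundary between two adjacent, non-zero, differing supervoxels
sv_boundary = (ws[:-1] != 0) & (ws[1:] != 0) & (ws[:-1] != ws[1:])
# adjacent voxels that belong to the same agglomerated object
same_agg = agg[:-1] == agg[1:]

connected = {tuple(sorted(map(int, p)))
             for p in zip(ws[:-1][sv_boundary & same_agg],
                          ws[1:][sv_boundary & same_agg])}
disconnected = {tuple(sorted(map(int, p)))
                for p in zip(ws[:-1][sv_boundary & ~same_agg],
                             ws[1:][sv_boundary & ~same_agg])}
assert connected == {(5, 7)} and disconnected == {(7, 9)}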
- edges_center_disconnected = {x if x[0] < x[1] else (x[1], x[0]) for x in chain( - zip(original[:-2, :-1, :-1][sv_boundaries_x[:-1, :-1, :-1] & ~agg_boundaries_x[:-1, :-1, :-1]], - original[1:-1, :-1, :-1][sv_boundaries_x[:-1, :-1, :-1] & ~agg_boundaries_x[:-1, :-1, :-1]]), - zip(original[:-1, :-2, :-1][sv_boundaries_y[:-1, :-1, :-1] & ~agg_boundaries_y[:-1, :-1, :-1]], - original[:-1, 1:-1, :-1][sv_boundaries_y[:-1, :-1, :-1] & ~agg_boundaries_y[:-1, :-1, :-1]]), - zip(original[:-1, :-1, :-2][sv_boundaries_z[:-1, :-1, :-1] & ~agg_boundaries_z[:-1, :-1, :-1]], - original[:-1, :-1, 1:-1][sv_boundaries_z[:-1, :-1, :-1] & ~agg_boundaries_z[:-1, :-1, :-1]]))} - - # Look up the affinity information for each edge and replace - # original supervoxel IDs with relabeled IDs - if edges_center_disconnected: - edges_center_disconnected = np.array([ - (*sorted(itemgetter(x[0], x[1])(rg2cg_center)), x[2], x[3]) - for x in [regiongraph_edges[e] for e in edges_center_disconnected] - ], dtype='uint64, uint64, float32, uint64') - else: - edges_center_disconnected = np.array([], dtype='uint64, uint64, float32, uint64') - - # Check if there are supervoxel that are not connected to any other - # supervoxel - surrounded by ID 0 - isolated_sv = set(rg2cg_center.values()) - for e in chain(edges_center_connected, edges_center_disconnected): - isolated_sv.discard(e[0]) - isolated_sv.discard(e[1]) - isolated_sv = np.array(list(isolated_sv), dtype=np.uint64) - - # XPlus Chunk: - # Collect edges between center chunk and the chunk in X+ direction. - # Slightly different approach because the relabeling lookup needs - # to be done for two different dictionaries. Slower, but fast enough - # due to far fewer edges near the boundary. - # Node ID layout guarantees that center chunk IDs are always smaller - # than IDs of positive neighboring chunks. - edges_xplus_connected = np.array(list({ - (rg2cg_center[x[0]], - rg2cg_xplus[x[1]], - *regiongraph_edges[x if x[0] < x[1] else (x[1], x[0])].item()[2:]) for x in zip( - original[-2:-1, :-1, :-1][sv_boundaries_x[-1:, :-1, :-1] & agg_boundaries_x[-1:, :-1, :-1]], - original[-1:, :-1, :-1][sv_boundaries_x[-1:, :-1, :-1] & agg_boundaries_x[-1:, :-1, :-1]]) - }), dtype='uint64, uint64, float32, uint64') - - edges_xplus_disconnected = np.array(list({ - (rg2cg_center[x[0]], - rg2cg_xplus[x[1]], - *regiongraph_edges[x if x[0] < x[1] else (x[1], x[0])].item()[2:]) for x in zip( - original[-2:-1, :-1, :-1][sv_boundaries_x[-1:, :-1, :-1] & ~agg_boundaries_x[-1:, :-1, :-1]], - original[-1:, :-1, :-1][sv_boundaries_x[-1:, :-1, :-1] & ~agg_boundaries_x[-1:, :-1, :-1]]) - }), dtype='uint64, uint64, float32, uint64') - - # Unbreakable edges (caused by relabeling and chunking) don't have - # sum of area or affinity values - edges_xplus_unbreakable = np.array(list({ - (rg2cg_center[x], rg2cg_xplus[x]) for x in np.unique( - original[-2:-1, :-1, :-1][(original[-2:-1, :-1, :-1] != UINT64_ZERO) & - (original[-2:-1, :-1, :-1] == original[-1:, :-1, :-1])]) - }), dtype='uint64, uint64') - - # YPlus Chunk: - # Collect edges between center chunk and the chunk in Y+ direction. 
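# The Y+ and Z+ blocks below repeat the X+ pattern along the other two axes.
# A minimal sketch (with made-up IDs) of the "unbreakable" edge construction
# used just above: a supervoxel whose original ID appears on both sides of the
# chunk face receives a different relabeled ID in each chunk, and the two IDs
# are joined by an edge flagged as unbreakable.
import numpy as np

face_center = np.array([0, 11, 11, 12])   # toy: last voxel slice of the center chunk
face_plus = np.array([0, 11, 13, 12])     # toy: first voxel slice of the neighbor
rg2cg_c = {11: 101, 12: 102}              # toy original -> relabeled IDs, center chunk
rg2cg_p = {11: 201, 12: 202, 13: 203}     # toy original -> relabeled IDs, neighbor chunk

spanning = np.unique(face_center[(face_center != 0) & (face_center == face_plus)])
unbreakable = {(rg2cg_c[int(s)], rg2cg_p[int(s)]) for s in spanning}
assert unbreakable == {(101, 201), (102, 202)}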
- edges_yplus_connected = np.array(list({ - (rg2cg_center[x[0]], - rg2cg_yplus[x[1]], - *regiongraph_edges[x if x[0] < x[1] else (x[1], x[0])].item()[2:]) for x in zip( - original[:-1, -2:-1, :-1][sv_boundaries_y[:-1, -1:, :-1] & agg_boundaries_y[:-1, -1:, :-1]], - original[:-1, -1:, :-1][sv_boundaries_y[:-1, -1:, :-1] & agg_boundaries_y[:-1, -1:, :-1]]) - }), dtype='uint64, uint64, float32, uint64') - - edges_yplus_disconnected = np.array(list({ - (rg2cg_center[x[0]], - rg2cg_yplus[x[1]], - *regiongraph_edges[x if x[0] < x[1] else (x[1], x[0])].item()[2:]) for x in zip( - original[:-1, -2:-1, :-1][sv_boundaries_y[:-1, -1:, :-1] & ~agg_boundaries_y[:-1, -1:, :-1]], - original[:-1, -1:, :-1][sv_boundaries_y[:-1, -1:, :-1] & ~agg_boundaries_y[:-1, -1:, :-1]]) - }), dtype='uint64, uint64, float32, uint64') - - edges_yplus_unbreakable = np.array(list({ - (rg2cg_center[x], rg2cg_yplus[x]) for x in np.unique( - original[:-1, -2:-1, :-1][(original[:-1, -2:-1, :-1] != UINT64_ZERO) & - (original[:-1, -2:-1, :-1] == original[:-1, -1:, :-1])]) - }), dtype='uint64, uint64') - - # ZPlus Chunk - # Collect edges between center chunk and the chunk in Z+ direction. - edges_zplus_connected = np.array(list({ - (rg2cg_center[x[0]], - rg2cg_zplus[x[1]], - *regiongraph_edges[x if x[0] < x[1] else (x[1], x[0])].item()[2:]) for x in zip( - original[:-1, :-1, -2:-1][sv_boundaries_z[:-1, :-1, -1:] & agg_boundaries_z[:-1, :-1, -1:]], - original[:-1, :-1, -1:][sv_boundaries_z[:-1, :-1, -1:] & agg_boundaries_z[:-1, :-1, -1:]]) - }), dtype='uint64, uint64, float32, uint64') - - edges_zplus_disconnected = np.array(list({ - (rg2cg_center[x[0]], - rg2cg_zplus[x[1]], - *regiongraph_edges[x if x[0] < x[1] else (x[1], x[0])].item()[2:]) for x in zip( - original[:-1, :-1, -2:-1][sv_boundaries_z[:-1, :-1, -1:] & ~agg_boundaries_z[:-1, :-1, -1:]], - original[:-1, :-1, -1:][sv_boundaries_z[:-1, :-1, -1:] & ~agg_boundaries_z[:-1, :-1, -1:]]) - }), dtype='uint64, uint64, float32, uint64') - - edges_zplus_unbreakable = np.array(list({ - (rg2cg_center[x], rg2cg_zplus[x]) for x in np.unique( - original[:-1, :-1, -2:-1][(original[:-1, :-1, -2:-1] != UINT64_ZERO) & - (original[:-1, :-1, -2:-1] == original[:-1, :-1, -1:])]) - }), dtype='uint64, uint64') - else: - print("Fast skipping Regiongraph calculation - empty block") - - # Prepare upload - rg2cg_center_str = slice_to_str( - slice(self.__roi[x].start, - self.__roi[x].start + int(self.__cgraph.chunk_size[x])) for x in range(3)) - - rg2cg_xplus_str = slice_to_str(( - slice(self.__roi[0].start + int(self.__cgraph.chunk_size[0]) - 1, - self.__roi[0].start + int(self.__cgraph.chunk_size[0]) + 1), - slice(self.__roi[1].start, self.__roi[1].start + int(self.__cgraph.chunk_size[1])), - slice(self.__roi[2].start, self.__roi[2].start + int(self.__cgraph.chunk_size[2])))) - - rg2cg_yplus_str = slice_to_str(( - slice(self.__roi[0].start, self.__roi[0].start + int(self.__cgraph.chunk_size[0])), - slice(self.__roi[1].start + int(self.__cgraph.chunk_size[1]) - 1, - self.__roi[1].start + int(self.__cgraph.chunk_size[1]) + 1), - slice(self.__roi[2].start, self.__roi[2].start + int(self.__cgraph.chunk_size[2])))) - - rg2cg_zplus_str = slice_to_str(( - slice(self.__roi[0].start, self.__roi[0].start + int(self.__cgraph.chunk_size[0])), - slice(self.__roi[1].start, self.__roi[1].start + int(self.__cgraph.chunk_size[1])), - slice(self.__roi[2].start + int(self.__cgraph.chunk_size[2]) - 1, - self.__roi[2].start + int(self.__cgraph.chunk_size[2]) + 1))) - - print("Uploading edges") - 
self.__regiongraph["storage_out"].put_files( - files=[(rg2cg_center_str + '_connected.bin', edges_center_connected.tobytes()), - (rg2cg_center_str + '_disconnected.bin', edges_center_disconnected.tobytes()), - (rg2cg_center_str + '_isolated.bin', isolated_sv.tobytes()), - (rg2cg_xplus_str + '_connected.bin', edges_xplus_connected.tobytes()), - (rg2cg_xplus_str + '_disconnected.bin', edges_xplus_disconnected.tobytes()), - (rg2cg_xplus_str + '_unbreakable.bin', edges_xplus_unbreakable.tobytes()), - (rg2cg_yplus_str + '_connected.bin', edges_yplus_connected.tobytes()), - (rg2cg_yplus_str + '_disconnected.bin', edges_yplus_disconnected.tobytes()), - (rg2cg_yplus_str + '_unbreakable.bin', edges_yplus_unbreakable.tobytes()), - (rg2cg_zplus_str + '_connected.bin', edges_zplus_connected.tobytes()), - (rg2cg_zplus_str + '_disconnected.bin', edges_zplus_disconnected.tobytes()), - (rg2cg_zplus_str + '_unbreakable.bin', edges_zplus_unbreakable.tobytes())], - content_type='application/octet-stream') - print("Done") - - -@lru_cache(maxsize=32) -def load_rg_chunk_affinities(regiongraph_storage: Storage, chunk_path: str) -> np.ndarray: - """ - Extract weighted supervoxel edges from zstd compressed Region Graph - file `chunk_path`. - The unversioned, custom binary file format shall be called RanStruct, - which, as of 2018-08-03, looks like this: - - struct RanStruct # Little Endian, not aligned -> 56 Byte - segA1::UInt64 - segB1::UInt64 - sum_aff1::Float32 - sum_area1::UInt64 - segA2::UInt64 # same as segA1 - segB2::UInt64 # same as segB1 - sum_aff2::Float32 # same as sum_aff1 - sum_area2::UInt64 # same as sum_area1 - end - - The big top level Region Graph chunks get requested almost every time, - thus the memoization. - """ - - f = regiongraph_storage.get_file(chunk_path) - if not f: - Warning("%s doesn't exist" % chunk_path) - return np.array([], dtype='uint64, uint64, float32, uint64') - - dctx = zstd.ZstdDecompressor() - decompressed = dctx.decompress(f) - - buf = np.frombuffer(decompressed, dtype='uint64, uint64, float32, uint64') - return np.lib.stride_tricks.as_strided( - buf, - shape=tuple(x//2 for x in buf.shape), - strides=tuple(x*2 for x in buf.strides), - writeable=False - ) - - -def str_to_slice(slice_str: str) -> Tuple[slice, slice, slice]: - match = re.match(r"(\d+)-(\d+)_(\d+)-(\d+)_(\d+)-(\d+)", slice_str) - return (slice(int(match.group(1)), int(match.group(2))), - slice(int(match.group(3)), int(match.group(4))), - slice(int(match.group(5)), int(match.group(6)))) - - -def slice_to_str(slices: Union[slice, Iterable[slice]]) -> str: - if isinstance(slices, slice): - return "%d-%d" % (slices.start, slices.stop) - else: - return '_'.join(map(slice_to_str, slices)) - - -def run_task_bundle(settings: Mapping, roi: Tuple[slice, slice, slice]): - # Remember: DB must be cleared before starting a whole new run - with open("/secrets/mysql") as passwd: - mysql_conn = MySQLdb.connect( - host=settings["mysql"]["host"], - user=settings["mysql"]["user"], - db=settings["mysql"]["db"], - passwd=passwd.read().strip() - ) - - cgraph = chunkedgraph.ChunkedGraph( - table_id=settings["chunkedgraph"]["table_id"], - instance_id=settings["chunkedgraph"]["instance_id"] - ) - - # Things to check: - # - Agglomeration and Input Watershed have the same offset/size - # - Taskbundle Offset and ROI is a multiple of cgraph.chunksize - # - Output Watershed chunksize must be a multiple of cgraph.chunksize - - agglomeration_input = CloudVolume( - settings["layers"]["agglomeration_path_input"], bounded=False) - watershed_input 
= CloudVolume( - settings["layers"]["watershed_path_input"], bounded=False) - watershed_output = CloudVolume( - settings["layers"]["watershed_path_output"], bounded=False, autocrop=True) - regiongraph_input = Storage( - settings["regiongraph"]["regiongraph_path_input"]) - regiongraph_output = Storage( - settings["regiongraph"]["regiongraph_path_output"]) - regiongraph_chunksize = tuple(settings["regiongraph"]["chunksize"]) - - chunkgraph_chunksize = np.array(cgraph.chunk_size, dtype=np.int) - output_watershed_chunksize = np.array(watershed_output.underlying, dtype=np.int) - outer_chunksize = np.maximum(chunkgraph_chunksize, output_watershed_chunksize, dtype=np.int) - - # Iterate through TaskBundle using a minimal chunk size that is a multiple - # of the output watershed chunk size and the Chunked Graph chunk size. - for ox in range(roi[0].start, roi[0].stop, outer_chunksize[0]): - for oy in range(roi[1].start, roi[1].stop, outer_chunksize[1]): - for oz in range(roi[2].start, roi[2].stop, outer_chunksize[2]): - - watershed_output_buffer = np.zeros((*outer_chunksize, 1), dtype=np.uint64) - - # Iterate through ChunkGraph chunk-sized tasks: - for ix_start in range(0, outer_chunksize[0], chunkgraph_chunksize[0]): - for iy_start in range(0, outer_chunksize[1], chunkgraph_chunksize[1]): - for iz_start in range(0, outer_chunksize[2], chunkgraph_chunksize[2]): - ix_end = ix_start + chunkgraph_chunksize[0] - iy_end = iy_start + chunkgraph_chunksize[1] - iz_end = iz_start + chunkgraph_chunksize[2] - - # One voxel overlap in each dimension to get - # consistent labeling across chunks - edgetask_roi = (slice(ox + ix_start, ox + ix_end + 1), - slice(oy + iy_start, oy + iy_end + 1), - slice(oz + iz_start, oz + iz_end + 1)) - - edgetask = EdgeTask( - cgraph=cgraph, - mysql_conn=mysql_conn, - agglomeration_input=agglomeration_input, - watershed_input=watershed_input, - regiongraph_input=regiongraph_input, - regiongraph_output=regiongraph_output, - regiongraph_chunksize=regiongraph_chunksize, - roi=edgetask_roi - ) - edgetask.execute() - - # Write relabeled ChunkGraph chunk to (possibly larger) - # watershed-chunk aligned buffer - watershed_output_buffer[ix_start:ix_end, - iy_start:iy_end, - iz_start:iz_end, :] = \ - edgetask.get_relabeled_watershed() - - watershed_output[ox:ox + outer_chunksize[0], - oy:oy + outer_chunksize[1], - oz:oz + outer_chunksize[2], :] = \ - watershed_output_buffer - - -if __name__ == "__main__": - params = json.loads(sys.argv[1]) - run_task_bundle(params, str_to_slice(sys.argv[2])) diff --git a/pychunkedgraph/edge_gen/requirements.txt b/pychunkedgraph/edge_gen/requirements.txt deleted file mode 100644 index a35e3b27a..000000000 --- a/pychunkedgraph/edge_gen/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -cloud-volume -mysqlclient -zstandard diff --git a/pychunkedgraph/examples/__init__.py b/pychunkedgraph/examples/__init__.py deleted file mode 100644 index 5223e6144..000000000 --- a/pychunkedgraph/examples/__init__.py +++ /dev/null @@ -1,76 +0,0 @@ -from flask import Flask -from flask.logging import default_handler -from flask_cors import CORS -import sys -import logging -import os -import time -import json -import numpy as np -import datetime -from pychunkedgraph.app import config -import redis -from rq import Queue - -from pychunkedgraph.examples.parallel_test.main import init_parallel_test_cmds -from pychunkedgraph.meshing.meshing_test_temp import init_mesh_cmds - -# from pychunkedgraph.app import app_blueprint -from pychunkedgraph.app import cg_app_blueprint, 
meshing_app_blueprint -from pychunkedgraph.logging import jsonformatter -# from pychunkedgraph.app import manifest_app_blueprint -os.environ['TRAVIS_BRANCH'] = "IDONTKNOWWHYINEEDTHIS" - - -class CustomJsonEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, datetime.datetime): - return obj.__str__() - return json.JSONEncoder.default(self, obj) - - -def create_example_app(test_config=None): - app = Flask(__name__) - app.json_encoder = CustomJsonEncoder - - configure_app(app) - - app.register_blueprint(cg_app_blueprint.bp) - app.register_blueprint(meshing_app_blueprint.bp) - # app.register_blueprint(manifest_app_blueprint.bp) - - with app.app_context(): - init_parallel_test_cmds(app) - init_mesh_cmds(app) - - return app - - -def configure_app(app): - # Load logging scheme from config.py - app_settings = os.getenv('APP_SETTINGS') - if not app_settings: - app.config.from_object(config.BaseConfig) - else: - app.config.from_object(app_settings) - - - # Configure logging - # handler = logging.FileHandler(app.config['LOGGING_LOCATION']) - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(app.config['LOGGING_LEVEL']) - formatter = jsonformatter.JsonFormatter( - fmt=app.config['LOGGING_FORMAT'], - datefmt=app.config['LOGGING_DATEFORMAT']) - formatter.converter = time.gmtime - handler.setFormatter(formatter) - app.logger.removeHandler(default_handler) - app.logger.addHandler(handler) - app.logger.setLevel(app.config['LOGGING_LEVEL']) - app.logger.propagate = False - - if app.config['USE_REDIS_JOBS']: - app.redis = redis.Redis.from_url(app.config['REDIS_URL']) - app.test_q = Queue('test' ,connection=app.redis) \ No newline at end of file diff --git a/pychunkedgraph/examples/parallel_test/main.py b/pychunkedgraph/examples/parallel_test/main.py deleted file mode 100644 index 18aef6c6b..000000000 --- a/pychunkedgraph/examples/parallel_test/main.py +++ /dev/null @@ -1,37 +0,0 @@ -import click -import redis - -from flask import current_app -from flask.cli import AppGroup -from pychunkedgraph.examples.parallel_test.tasks import independent_task - -ingest_cli = AppGroup('parallel') - -def handler(*args, **kwargs): - ''' - Message handler function, called by redis - when a message is received on pubsub channel - ''' - print(args) - print(kwargs) - - -@ingest_cli.command('test') -@click.argument('n', type=int) -@click.argument('size', type=int) -def create_atomic_chunks(n, size): - print(f'Queueing {n} items of size {size} ...') - chunk_pubsub = current_app.redis.pubsub() - chunk_pubsub.subscribe(**{'test-channel': handler}) - - for item_id in range(n): - current_app.test_q.enqueue( - independent_task, - args=(item_id, size)) - - thread = chunk_pubsub.run_in_thread(sleep_time=0.1) - return 'Queued' - - -def init_parallel_test_cmds(app): - app.cli.add_command(ingest_cli) \ No newline at end of file diff --git a/pychunkedgraph/examples/parallel_test/tasks.py b/pychunkedgraph/examples/parallel_test/tasks.py deleted file mode 100644 index 46b956009..000000000 --- a/pychunkedgraph/examples/parallel_test/tasks.py +++ /dev/null @@ -1,21 +0,0 @@ -import os -import time -from flask import current_app -from pychunkedgraph.utils.general import redis_job - -# not a good solution -# figure out how to use app context - -REDIS_HOST = os.environ.get('REDIS_SERVICE_HOST', 'localhost') -REDIS_PORT = os.environ.get('REDIS_SERVICE_PORT', '6379') -REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD', 'dev') -REDIS_URL = 
f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0' - -@redis_job(REDIS_URL, 'test-channel') -def independent_task(chunk_id, chunk_size): - print(f' Working on chunk id: {chunk_id}, size {chunk_size}') - i = 0 - while i < chunk_size: - i += 1 - print('Done') - return chunk_id \ No newline at end of file diff --git a/pychunkedgraph/export/__init__.py b/pychunkedgraph/export/__init__.py new file mode 100644 index 000000000..b48191f33 --- /dev/null +++ b/pychunkedgraph/export/__init__.py @@ -0,0 +1,3 @@ +""" +Modules for exporting data out of storage platform. +""" \ No newline at end of file diff --git a/pychunkedgraph/export/models.py b/pychunkedgraph/export/models.py new file mode 100644 index 000000000..55cd974d4 --- /dev/null +++ b/pychunkedgraph/export/models.py @@ -0,0 +1,82 @@ +from typing import Iterable +from dataclasses import dataclass +from datetime import datetime + + +@dataclass +class OperationLogBase: + """ + Base class for log format. + """ + + id: str + user: str + timestamp: datetime + status: int + roots: Iterable + source_coords: Iterable + sink_coords: Iterable + old_roots: Iterable + old_roots_ts: Iterable + exception: str + operation_ts: datetime + + def __init__(self, **kwargs): + self.id = kwargs.get("id") + self.user = kwargs.get("user") + self.timestamp = kwargs.get("timestamp") + self.status = kwargs.get("status") + self.roots = kwargs.get("roots") + self.source_coords = kwargs.get("source_coords") + self.sink_coords = kwargs.get("sink_coords") + self.old_roots = kwargs.get("old_roots") + self.old_roots_ts = kwargs.get("old_roots_ts") + self.exception = kwargs.get("operation_exception") + # this was added recently + # for older logs assume log_timestamp = operation_timestamp + self.operation_ts = kwargs.get("operation_ts", self.timestamp) + + +class OperationLog: + """ + Wrapper class for log format. 
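+
+    Dispatches on the presence of `added_edges` in the keyword arguments:
+    a merge log carries added edges, everything else is treated as a split.
+    A minimal usage sketch (field values are illustrative only):
+
+        OperationLog(added_edges=[[1, 2]])                  # -> MergeLog
+        OperationLog(source_ids=[1], sink_ids=[2],
+                     bb_offset=[240, 240, 24])              # -> SplitLog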
+ """ + + def __new__(cls, **kwargs): + if "added_edges" in kwargs: + return MergeLog(**kwargs) + return SplitLog(**kwargs) + + +@dataclass +class MergeLog(OperationLogBase): + """Log class for merge operation.""" + + added_edges: Iterable + + def __init__(self, **kwargs): + added_edges = kwargs.pop("added_edges") + super().__init__(**kwargs) + self.added_edges = added_edges + + +@dataclass +class SplitLog(OperationLogBase): + """Log class for split operation.""" + + source_ids: Iterable + sink_ids: Iterable + bb_offset: Iterable + removed_edges: Iterable + + def __init__(self, **kwargs): + source_ids = kwargs.pop("source_ids") + sink_ids = kwargs.pop("sink_ids") + bb_offset = kwargs.pop("bb_offset") + removed_edges = kwargs.pop("removed_edges", []) + super().__init__(**kwargs) + self.source_ids = source_ids + self.sink_ids = sink_ids + self.bb_offset = bb_offset + self.removed_edges = removed_edges + diff --git a/pychunkedgraph/export/operation_logs.py b/pychunkedgraph/export/operation_logs.py new file mode 100644 index 000000000..ec7141ce7 --- /dev/null +++ b/pychunkedgraph/export/operation_logs.py @@ -0,0 +1,81 @@ +from typing import Optional +from typing import Iterable +from datetime import datetime + +from .models import OperationLog +from ..graph import ChunkedGraph +from ..graph.attributes import OperationLogs + + +def parse_attr(attr, val) -> str: + from numpy import ndarray + + try: + if isinstance(val, OperationLogs.StatusCodes): + return (attr.key, val.value) + if isinstance(val, ndarray): + return (attr.key, val.tolist()) + return (attr.key, val) + except AttributeError: + return (attr, val) + + +def get_parsed_logs( + cg: ChunkedGraph, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, +) -> Iterable[OperationLog]: + """Parse logs for compatibility with destination platform.""" + logs = cg.client.read_log_entries(start_time=start_time, end_time=end_time) + result = [] + for _id, _log in logs.items(): + log = {"id": int(_id)} + log["status"] = int(_log.get("operation_status", 0)) + for attr, val in _log.items(): + attr, val = parse_attr(attr, val) + try: + log[attr.decode("utf-8")] = val + except AttributeError: + log[attr] = val + result.append(OperationLog(**log)) + print(f"total raw logs {len(result)}") + return result + + +def get_logs_with_previous_roots( + cg: ChunkedGraph, parsed_logs: Iterable[OperationLog] +) -> Iterable[OperationLog]: + """ + Adds a new entry for new roots' previous IDs. + And timestamps for those roots. 
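+    Logs whose previous roots cannot be found are passed through unchanged;
+    a `WRITE_STARTED` status is assumed to mean the write never completed.
+    The lookup is batched over all roots to avoid repeated network calls.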
+ """ + from numpy import array + from numpy import unique + from numpy import concatenate + from ..graph.types import empty_1d + from ..graph.lineage import get_previous_root_ids + from ..graph.utils.basetypes import NODE_ID + + print(f"getting olg roots for {len(parsed_logs)} logs.") + roots = [empty_1d] + for log in parsed_logs: + if len(log.roots): + roots.append(array(log.roots, dtype=NODE_ID)) + roots = concatenate(roots) + # get previous roots for all to avoid multiple network calls + old_roots_d = get_previous_root_ids(cg, roots) + old_roots_all = concatenate([empty_1d, *old_roots_d.values()]) + + old_roots_ts = cg.get_node_timestamps(old_roots_all).tolist() + old_roots_ts_d = dict(zip(old_roots_all, old_roots_ts)) + + for log in parsed_logs: + try: + old_roots = concatenate([old_roots_d[id_] for id_ in log.roots]) + log.old_roots = unique(old_roots).tolist() + log.old_roots_ts = [old_roots_ts_d[id_] for id_ in log.old_roots] + except (ValueError, KeyError): + # if old roots don't exist that means writing was not successful + # NOTE: if status is `WRITE_STARTED` writing is assumed to have failed + pass + return parsed_logs diff --git a/pychunkedgraph/export/to/__init__.py b/pychunkedgraph/export/to/__init__.py new file mode 100644 index 000000000..cadf9871f --- /dev/null +++ b/pychunkedgraph/export/to/__init__.py @@ -0,0 +1 @@ +"""Client code for destination platforms.""" \ No newline at end of file diff --git a/pychunkedgraph/export/to/datastore/__init__.py b/pychunkedgraph/export/to/datastore/__init__.py new file mode 100644 index 000000000..ef7b83436 --- /dev/null +++ b/pychunkedgraph/export/to/datastore/__init__.py @@ -0,0 +1,225 @@ +from typing import Dict +from typing import Optional +from typing import Iterable +from datetime import datetime + +from google.cloud import datastore + +from .config import OperationLogsConfig +from ...models import OperationLog +from ....graph import ChunkedGraph +from ....utils.general import chunked + + +def _create_col_for_each_root(parsed_logs: Iterable[OperationLog]) -> Iterable[Dict]: + """ + Creates a new column for each old and new roots of an operation. + This makes querying easier. For eg, a split operation yields 2 new roots: + old_roots = [123] -> old_root1_col = 123 + new_roots = [124,125] -> new_root1_col = 124, new_root2_col = 125 + """ + from dataclasses import asdict + + count = 0 + + result = [] + for log in parsed_logs: + if log.status == 4: + count += 1 + log_d = asdict(log) + roots = log_d.pop("roots") + old_roots = log_d.pop("old_roots") + old_roots_ts = log_d.pop("old_roots_ts") + + for i, root in enumerate(roots): + log_d[f"root_{i+1}"] = str(root) + + if not (old_roots and old_roots_ts): + result.append(log_d) + continue + + for i, root_info in enumerate(zip(old_roots, old_roots_ts)): + log_d[f"old_root_{i+1}"] = str(root_info[0]) + log_d[f"old_root_{i+1}_ts"] = root_info[1] + result.append(log_d) + print(f"failed {count}") + # if count > int(len(result) * 0.2): + # raise ValueError( + # f"Something's wrong I can feel it, failed count {count}/{len(result)}" + # ) + return result + + +def _nested_lists_to_string(parsed_logs: Iterable[Dict]) -> Iterable[Dict]: + result = [] + for log in parsed_logs: + for k, v in log.items(): + if isinstance(v, list): + log[k] = str(v).replace(" ", "") + result.append(log) + return result + + +def delete_entities(namespace: str, kind: str) -> None: + """ + Deletes all entities of the given kind in given namespace. + Use this only when you need to "clean up". 
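+    Entities are deleted in batches of 500, the Datastore per-request limit.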
+ """ + client = datastore.Client() + + query = client.query(kind=kind, namespace=namespace) + keys = [] + for result in query.fetch(): + keys.append(result.key) + + for chunk in chunked(keys, 500): + client.delete_multi(chunk) + print(f"deleted {len(keys)} entities") + + +def _get_last_timestamp( + client: datastore.Client, + last_export_key: datastore.Key, + operation_logs_config: OperationLogsConfig, +): + from datetime import timedelta + + export_info = client.get(last_export_key) + try: + start_ts = export_info.get(operation_logs_config.EXPORT.LAST_EXPORT_TS) + start_ts = datetime(*list(start_ts.utctimetuple()[:4])) + start_ts -= timedelta(hours=1) + except AttributeError: + start_ts = None + export_ts = datetime.now() + # operation status changes while operation is running + # export is at least an hour behind to ensure all logs have a finalized status. + end_ts = datetime(*list(export_ts.utctimetuple()[:4])) + end_ts -= timedelta(hours=1) + if start_ts == end_ts: + end_ts += timedelta(seconds=1) + return start_ts, end_ts, export_ts + + +def _write_removed_edges(path: str, removed_edges: Iterable) -> None: + """ + Store removed edges of an operation in a bucket. + For some split operations there can be a large number of removed edges + that can't be written to datastore due to size limitations. + """ + if not len(removed_edges): + return + from json import dumps + from cloudfiles import CloudFiles + + cf = CloudFiles(path) + files_to_write = [] + for info in removed_edges: + op_id, edges = info + if not len(edges): + continue + files_to_write.append( + {"content": dumps(edges), "path": f"{op_id}.gz", "compress": "gzip"} + ) + cf.puts(files_to_write) + + +def export_operation_logs( + cg: ChunkedGraph, + start_ts: datetime = None, + end_ts: datetime = None, + namespace: str = None, +) -> int: + """ + Main function to export logs to Datastore. + Returns number of operations that failed to while persisting new IDs. + """ + from os import environ + from .config import DEFAULT_NS + from .config import OperationLogsConfig + from ... 
import operation_logs + + try: + client = datastore.Client().from_service_account_json( + environ["OPERATION_LOGS_DATASTORE_CREDENTIALS"] + ) + except KeyError: + # use GOOGLE_APPLICATION_CREDENTIALS + # this is usually set to "/root/.cloudvolume/secrets/.json" + client = datastore.Client() + + namespace_ = DEFAULT_NS + if namespace: + namespace_ = namespace + config = OperationLogsConfig(NAMESPACE=namespace_) + + last_export_key = client.key( + config.EXPORT.KIND, cg.graph_id, namespace=config.NAMESPACE + ) + start_ts_, end_ts_, export_ts = _get_last_timestamp(client, last_export_key, config) + if not start_ts: + start_ts = start_ts_ + if not end_ts: + end_ts = end_ts_ + + print(f"getting logs from chunkedgraph {cg.graph_id}") + print(f"start: {start_ts} end: {end_ts}") + logs = operation_logs.get_parsed_logs(cg, start_time=start_ts, end_time=end_ts) + logs = operation_logs.get_logs_with_previous_roots(cg, logs) + logs = _create_col_for_each_root(logs) + logs = _nested_lists_to_string(logs) + print(f"total logs {len(logs)}") + + count = 0 + failed_count = 0 + # datastore limits 500 entities per request + for chunk in chunked(logs, 500): + entities = [] + removed_edges = [] + for log in chunk: + kind = cg.graph_id + if log["status"] == 4: + kind = f"{cg.graph_id}_failed" + failed_count += 1 + op_id = log.pop("id") + op_log = datastore.Entity( + key=client.key(kind, op_id, namespace=config.NAMESPACE), + exclude_from_indexes=config.EXCLUDE_FROM_INDICES, + ) + removed_edges.append((op_id, log.pop("removed_edges", []))) + op_log.update(log) + entities.append(op_log) + client.put_multi(entities) + count += len(entities) + _write_removed_edges(f"{cg.meta.data_source.EDGES}/removed", removed_edges) + _update_stats(cg.graph_id, client, config, last_export_key, count, export_ts) + return failed_count + + +def _update_stats( + graph_id: str, + client: datastore.Client, + config: OperationLogsConfig, + last_export_key: datastore.Key, + logs_count: int, + export_ts: datetime, +): + export_log = datastore.Entity( + key=last_export_key, exclude_from_indexes=config.EXPORT.EXCLUDE_FROM_INDICES + ) + export_log[config.EXPORT.LAST_EXPORT_TS] = export_ts + export_log[config.EXPORT.LOGS_COUNT] = logs_count + + this_export_key = client.key( + config.EXPORT.KIND, + f"{graph_id}_{int(export_ts.timestamp())}", + namespace=config.NAMESPACE, + ) + this_export_log = datastore.Entity( + key=this_export_key, exclude_from_indexes=config.EXPORT.EXCLUDE_FROM_INDICES + ) + this_export_log[config.EXPORT.LAST_EXPORT_TS] = export_ts + this_export_log[config.EXPORT.LOGS_COUNT] = logs_count + + client.put_multi([export_log, this_export_log]) + print(f"export time {export_ts}, count {logs_count}") diff --git a/pychunkedgraph/export/to/datastore/config.py b/pychunkedgraph/export/to/datastore/config.py new file mode 100644 index 000000000..7a7e88da8 --- /dev/null +++ b/pychunkedgraph/export/to/datastore/config.py @@ -0,0 +1,35 @@ +from collections import namedtuple + +DEFAULT_NS = "pychunkedgraph_operation_logs" + + +_export_info_fields = ("KIND", "LOGS_COUNT", "LAST_EXPORT_TS", "EXCLUDE_FROM_INDICES") +_export_info_defaults = ( + "export_info", + "logs_count", + "last_export_ts", + ("logs_count", "last_export_ts",), +) +ExportInfo = namedtuple( + "ExportInfo", _export_info_fields, defaults=_export_info_defaults, +) + + +_operation_log_fields = ("NAMESPACE", "EXPORT", "EXCLUDE_FROM_INDICES") +_operation_log_defaults = ( + DEFAULT_NS, + ExportInfo(), + ( + "added_edges", + "removed_edges", + "source_coords", + "sink_coords", + 
"exception", + "source_ids", + "sink_ids", + "bb_offset", + ), +) +OperationLogsConfig = namedtuple( + "OperationLogsConfig", _operation_log_fields, defaults=_operation_log_defaults, +) diff --git a/pychunkedgraph/exporting/export.py b/pychunkedgraph/exporting/export.py deleted file mode 100644 index 53c6471cd..000000000 --- a/pychunkedgraph/exporting/export.py +++ /dev/null @@ -1,222 +0,0 @@ -import numpy as np -import cloudvolume -import itertools -import dill - -from pychunkedgraph.backend import chunkedgraph -from pychunkedgraph.backend.utils import serializers, column_keys -from multiwrapper import multiprocessing_utils as mu - - -def get_sv_to_root_id_mapping_chunk(cg, chunk_coords, vol=None): - """ Acquires a svid -> rootid dictionary for a chunk - - :param cg: chunkedgraph instance - :param chunk_coords: list - :return: dict - """ - sv_to_root_mapping = {} - - chunk_coords = np.array(chunk_coords, dtype=np.int) - - if np.any((chunk_coords % cg.chunk_size) != 0): - raise Exception("Chunk coords have to match a chunk corner exactly") - - chunk_coords = chunk_coords / cg.chunk_size - chunk_coords = chunk_coords.astype(np.int) - bb = np.array([chunk_coords, chunk_coords + 1], dtype=np.int) - - remapped_vol = None - vol_shape = None - - if vol is not None: - vol_shape = vol.shape - vol = vol.flatten() - remapped_vol = np.zeros_like(vol) - - atomic_rows = cg.range_read_chunk(layer=1, x=chunk_coords[0], - y=chunk_coords[1], z=chunk_coords[2], - columns=column_keys.Hierarchy.Parent) - for atomic_id in atomic_rows.keys(): - # Check if already found the root for this supervoxel - if atomic_id in sv_to_root_mapping: - continue - - # Find root - root_id = cg.get_root(atomic_id) - sv_to_root_mapping[atomic_id] = root_id - - # Add atomic children of root_id - atomic_ids = cg.get_subgraph_nodes(root_id, bounding_box=bb, - bb_is_coordinate=False) - sv_to_root_mapping.update(dict(zip(atomic_ids, - [root_id] * len(atomic_ids)))) - - if remapped_vol is not None: - remapped_vol[np.in1d(vol, atomic_ids)] = root_id - - if remapped_vol is not None: - remapped_vol = remapped_vol.reshape(vol_shape) - return sv_to_root_mapping, remapped_vol - else: - return sv_to_root_mapping - - -def _write_flat_segmentation_thread(args): - """ Helper of write_flat_segmentation """ - cg_info, start_block, end_block, from_url, to_url, mip = args - - assert 'segmentation' in to_url - assert 'svenmd' in to_url - - from_cv = cloudvolume.CloudVolume(from_url, mip=mip) - to_cv = cloudvolume.CloudVolume(to_url, mip=mip) - - cg = chunkedgraph.ChunkedGraph(table_id=cg_info["table_id"], - instance_id=cg_info["instance_id"], - project_id=cg_info["project_id"], - credentials=cg_info["credentials"]) - - for block_z in range(start_block[2], end_block[2]): - z_start = block_z * cg.chunk_size[2] - z_end = (block_z + 1) * cg.chunk_size[2] - for block_y in range(start_block[1], end_block[1]): - y_start = block_y * cg.chunk_size[1] - y_end = (block_y + 1) * cg.chunk_size[1] - for block_x in range(start_block[0], end_block[0]): - x_start = block_x * cg.chunk_size[0] - x_end = (block_x + 1) * cg.chunk_size[0] - - block = from_cv[x_start: x_end, y_start: y_end, z_start: z_end] - - _, remapped_block = get_sv_to_root_id_mapping_chunk(cg, [x_start, y_start, z_start], block) - - to_cv[x_start: x_end, y_start: y_end, z_start: z_end] = remapped_block - - -def write_flat_segmentation(cg, dataset_name, bounding_box=None, block_factor=2, - n_threads=1, mip=0): - """ Applies the mapping in the chunkedgraph to the supervoxels to create - a flattened 
segmentation - - :param cg: chunkedgraph instance - :param dataset_name: str - :param bounding_box: np.array - :param block_factor: int - :param n_threads: int - :param mip: int - :return: bool - """ - - if dataset_name == "pinky": - from_url = "gs://neuroglancer/svenmd/pinky40_v11/watershed/" - to_url = "gs://neuroglancer/svenmd/pinky40_v11/segmentation/" - elif dataset_name == "basil": - from_url = "gs://neuroglancer/svenmd/basil_4k_oldnet_cg/watershed/" - to_url = "gs://neuroglancer/svenmd/basil_4k_oldnet_cg/segmentation/" - else: - raise Exception("Dataset unknown") - - from_cv = cloudvolume.CloudVolume(from_url, mip=mip) - - dataset_bounding_box = np.array(from_cv.bounds.to_list()) - - block_bounding_box_cg = \ - [np.floor(dataset_bounding_box[:3] / cg.chunk_size).astype(np.int), - np.ceil(dataset_bounding_box[3:] / cg.chunk_size).astype(np.int)] - - if bounding_box is not None: - bounding_box_cg = \ - [np.floor(bounding_box[0] / cg.chunk_size).astype(np.int), - np.ceil(bounding_box[1] / cg.chunk_size).astype(np.int)] - - m = block_bounding_box_cg[0] < bounding_box_cg[0] - block_bounding_box_cg[0][m] = bounding_box_cg[0][m] - - m = block_bounding_box_cg[1] > bounding_box_cg[1] - block_bounding_box_cg[1][m] = bounding_box_cg[1][m] - - block_iter = itertools.product(np.arange(block_bounding_box_cg[0][0], - block_bounding_box_cg[1][0], - block_factor), - np.arange(block_bounding_box_cg[0][1], - block_bounding_box_cg[1][1], - block_factor), - np.arange(block_bounding_box_cg[0][2], - block_bounding_box_cg[1][2], - block_factor)) - blocks = np.array(list(block_iter)) - - cg_info = cg.get_serialized_info() - - multi_args = [] - for start_block in blocks: - end_block = start_block + block_factor - m = end_block > block_bounding_box_cg[1] - end_block[m] = block_bounding_box_cg[1][m] - - multi_args.append([cg_info, start_block, end_block, - from_url, to_url, mip]) - - # Run parallelizing - if n_threads == 1: - mu.multiprocess_func(_write_flat_segmentation_thread, multi_args, - n_threads=n_threads, verbose=True, - debug=n_threads == 1) - else: - mu.multisubprocess_func(_write_flat_segmentation_thread, multi_args, - n_threads=n_threads) - - -def export_changelog(cg, path=None): - """ Exports all changes to binary dill file - - :param cg: ChunkedGraph instance - :param path: str - :return: bool - """ - - operations = cg.read_node_id_rows(start_id=np.uint64(0), - end_id=cg.get_max_operation_id(), - end_id_inclusive=True) - - if path is not None: - with open(path, "wb") as f: - dill.dump(operations, f) - else: - return operations - - -def load_changelog(path): - """ Loads stored changelog - - :param path: str - :return: - """ - - with open(path, "rb") as f: - operations = dill.load(f) - - # Dill can marshall the `serializer` functions used for each column, but their address - # won't match anymore, which breaks the hash lookup for `_Column`s. 
Hence, we simply create - # new `_Column`s from the old `family_id` and `key` - for operation_id, column_dict in operations.items(): - operations[operation_id] = \ - {column_keys.from_key(k.family_id, k.key): v for (k, v) in column_dict.items()} - - return operations - - -def get_log_diff(log_old, log_new): - """ Computes a simple difference between two logs - - :param log_old: dict - :param log_new: dict - :return: dict - """ - log = log_new.copy() - - for k in log_old: - del log[k] - - return log diff --git a/pychunkedgraph/graph/__init__.py b/pychunkedgraph/graph/__init__.py new file mode 100644 index 000000000..96b342427 --- /dev/null +++ b/pychunkedgraph/graph/__init__.py @@ -0,0 +1,2 @@ +from .chunkedgraph import ChunkedGraph +from .meta import ChunkedGraphMeta diff --git a/pychunkedgraph/benchmarking/__init__.py b/pychunkedgraph/graph/analysis/__init__.py similarity index 100% rename from pychunkedgraph/benchmarking/__init__.py rename to pychunkedgraph/graph/analysis/__init__.py diff --git a/pychunkedgraph/graph/analysis/pathing.py b/pychunkedgraph/graph/analysis/pathing.py new file mode 100644 index 000000000..062b7a1c3 --- /dev/null +++ b/pychunkedgraph/graph/analysis/pathing.py @@ -0,0 +1,227 @@ +import typing + +import fastremap +import graph_tool +import numpy as np + +from pychunkedgraph.graph.utils import flatgraph + +from ..subgraph import get_subgraph_nodes + + +def get_first_shared_parent( + cg, first_node_id: np.uint64, second_node_id: np.uint64, time_stamp=None +): + """ + Get the common parent of first_node_id and second_node_id with the lowest layer. + Returns None if the two nodes belong to different root ids. + :param first_node_id: np.uint64 + :param second_node_id: np.uint64 + :return: np.uint64 or None + """ + first_node_parent_ids = set() + second_node_parent_ids = set() + cur_first_node_parent = first_node_id + cur_second_node_parent = second_node_id + while cur_first_node_parent is not None or cur_second_node_parent is not None: + if cur_first_node_parent is not None: + first_node_parent_ids.add(cur_first_node_parent) + if cur_second_node_parent is not None: + second_node_parent_ids.add(cur_second_node_parent) + if cur_first_node_parent in second_node_parent_ids: + return cur_first_node_parent + if cur_second_node_parent in first_node_parent_ids: + return cur_second_node_parent + if cur_first_node_parent is not None: + cur_first_node_parent = cg.get_parent( + cur_first_node_parent, time_stamp=time_stamp + ) + if cur_second_node_parent is not None: + cur_second_node_parent = cg.get_parent( + cur_second_node_parent, time_stamp=time_stamp + ) + return None + + +def get_children_at_layer( + cg, + agglomeration_id: np.uint64, + layer: int, + allow_lower_layers: bool = False, +): + """ + Get the children of agglomeration_id that have layer = layer. 
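A rough usage sketch for the two pathing helpers above, assuming `cg` is an open ChunkedGraph instance; the node IDs are placeholders.

from pychunkedgraph.graph.analysis import pathing

first_node_id, second_node_id = 144115188075855872, 144115188075855873  # placeholder IDs
common = pathing.get_first_shared_parent(cg, first_node_id, second_node_id)
if common is not None:  # None means the two nodes have different roots
    # all level-2 descendants of the lowest common ancestor
    l2_ids = pathing.get_children_at_layer(cg, common, layer=2)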
+ :param agglomeration_id: np.uint64 + :param layer: int + :return: [np.uint64] + """ + nodes_to_query = [agglomeration_id] + children_at_layer = [] + while True: + children = cg.get_children(nodes_to_query, flatten=True) + children_layers = cg.get_chunk_layers(children) + if allow_lower_layers: + stop_layer_mask = children_layers <= layer + else: + stop_layer_mask = children_layers == layer + continue_layer_mask = children_layers > layer + found_children_at_layer = children[stop_layer_mask] + children_at_layer.append(found_children_at_layer) + nodes_to_query = children[continue_layer_mask] + if not np.any(nodes_to_query): + break + return np.concatenate(children_at_layer) + + +def get_lvl2_edge_list( + cg, + node_id: np.uint64, + bbox: typing.Optional[typing.Sequence[typing.Sequence[int]]] = None, +): + """get an edge list of lvl2 ids for a particular node + + :param cg: ChunkedGraph object + :param node_id: np.uint64 that you want the edge list for + :param bbox: Optional[Sequence[Sequence[int]]] a bounding box to limit the search + """ + + if bbox is None: + # maybe temporary, this was the old implementation + lvl2_ids = get_children_at_layer(cg, node_id, 2) + else: + lvl2_ids = get_subgraph_nodes( + cg, + node_id, + bbox=bbox, + bbox_is_coordinate=True, + return_layers=[2], + return_flattened=True, + ) + + edges = _get_edges_for_lvl2_ids(cg, lvl2_ids, induced=True) + return edges + + +def _get_edges_for_lvl2_ids(cg, lvl2_ids, induced=False): + # protect in case there are no lvl2 ids + if len(lvl2_ids) == 0: + return np.empty((0, 2), dtype=np.uint64) + + cce_dict = cg.get_atomic_cross_edges(lvl2_ids) + + # Gather all of the supervoxel ids into two lists, we will map them to + # their parent lvl2 ids + edge_array = [] + for l2_id in cce_dict: + for level in cce_dict[l2_id]: + edge_array.append(cce_dict[l2_id][level]) + + # protect in case there are no edges + if len(edge_array) == 0: + return np.empty((0, 2), dtype=np.uint64) + + edge_array = np.concatenate(edge_array) + known_supervoxels_list = [] + known_l2_list = [] + unknown_supervoxel_list = [] + for lvl2_id in cce_dict: + for level in cce_dict[lvl2_id]: + known_supervoxels_for_lv2_id = cce_dict[lvl2_id][level][:, 0] + unknown_supervoxels_for_lv2_id = cce_dict[lvl2_id][level][:, 1] + known_supervoxels_list.append(known_supervoxels_for_lv2_id) + known_l2_list.append(np.full(known_supervoxels_for_lv2_id.shape, lvl2_id)) + unknown_supervoxel_list.append(unknown_supervoxels_for_lv2_id) + + # Create two arrays to map supervoxels for which we know their parents + known_supervoxel_array, unique_indices = np.unique( + np.concatenate(known_supervoxels_list), return_index=True + ) + known_l2_array = (np.concatenate(known_l2_list))[unique_indices] + unknown_supervoxel_array = np.unique(np.concatenate(unknown_supervoxel_list)) + + # Call get_parents on any supervoxels for which we don't know their parents + supervoxels_to_query_parent = np.setdiff1d( + unknown_supervoxel_array, known_supervoxel_array + ) + if len(supervoxels_to_query_parent) > 0: + missing_l2_ids = cg.get_parents(supervoxels_to_query_parent) + known_supervoxel_array = np.concatenate( + (known_supervoxel_array, supervoxels_to_query_parent) + ) + known_l2_array = np.concatenate((known_l2_array, missing_l2_ids)) + + # Map the cross-chunk edges from supervoxels to lvl2 ids + edge_view = edge_array.view() + edge_view.shape = -1 + fastremap.remap_from_array_kv(edge_view, known_supervoxel_array, known_l2_array) + + edge_array = np.unique(np.sort(edge_array, axis=1), axis=0) + + if 
induced: + # make this an induced subgraph + # keep only the edges that are between the lvl2 ids asked for + edge_array = edge_array[ + np.isin(edge_array[:, 0], lvl2_ids) & np.isin(edge_array[:, 1], lvl2_ids) + ] + + return edge_array + + +def find_l2_shortest_path( + cg, source_l2_id: np.uint64, target_l2_id: np.uint64, time_stamp=None +): + """ + Find a path of level 2 ids that connect two level 2 node ids through cross chunk edges. + Return a list of level 2 ids representing this path. + Return None if the two level 2 ids do not belong to the same object. + :param cg: ChunkedGraph object + :param source_l2_id: np.uint64 + :param target_l2_id: np.uint64 + :return: [np.uint64] or None + """ + # Get the cross-chunk edges that we need to build the graph + shared_parent_id = get_first_shared_parent( + cg, source_l2_id, target_l2_id, time_stamp + ) + if shared_parent_id is None: + return None + + edge_array = get_lvl2_edge_list(cg, shared_parent_id) + # Create a graph-tool graph of the mapped cross-chunk-edges + weighted_graph, _, _, graph_indexed_l2_ids = flatgraph.build_gt_graph( + edge_array, is_directed=False + ) + + # Find the shortest path from the source_l2_id to the target_l2_id + source_graph_id = np.where(graph_indexed_l2_ids == source_l2_id)[0][0] + target_graph_id = np.where(graph_indexed_l2_ids == target_l2_id)[0][0] + source_vertex = weighted_graph.vertex(source_graph_id) + target_vertex = weighted_graph.vertex(target_graph_id) + vertex_list, _ = graph_tool.topology.shortest_path( + weighted_graph, source=source_vertex, target=target_vertex + ) + + # Remap the graph-tool ids to lvl2 ids and return the path + vertex_indices = [weighted_graph.vertex_index[vertex] for vertex in vertex_list] + l2_traversal_path = graph_indexed_l2_ids[vertex_indices] + return l2_traversal_path + + +def compute_rough_coordinate_path(cg, l2_ids): + """ + Given a list of l2_ids, return a list of rough coordinates representing + the path the l2_ids form. 
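For example, the two functions above can be chained to turn a pair of level-2 IDs into an approximate coordinate path; `cg` is an open ChunkedGraph and the two level-2 IDs are placeholders.

from pychunkedgraph.graph.analysis.pathing import (
    compute_rough_coordinate_path,
    find_l2_shortest_path,
)

l2_path = find_l2_shortest_path(cg, source_l2_id, target_l2_id)
if l2_path is not None:  # None means the two IDs belong to different objects
    coords = compute_rough_coordinate_path(cg, l2_path)  # one xyz per level-2 chunk center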
+ :param cg: ChunkedGraph object + :param l2_ids: Sequence[np.uint64] + :return: [np.ndarray] + """ + coordinate_path = [] + for l2_id in l2_ids: + chunk_center = cg.get_chunk_coordinates(l2_id) + np.array([0.5, 0.5, 0.5]) + coordinate = chunk_center * np.array( + cg.meta.graph_config.CHUNK_SIZE + ) + np.array(cg.meta.cv.mip_voxel_offset(0)) + coordinate = coordinate * np.array(cg.meta.cv.mip_resolution(0)) + coordinate = coordinate.astype(np.float32) + coordinate_path.append(coordinate) + return coordinate_path diff --git a/pychunkedgraph/graph/attributes.py b/pychunkedgraph/graph/attributes.py new file mode 100644 index 000000000..3e48d204a --- /dev/null +++ b/pychunkedgraph/graph/attributes.py @@ -0,0 +1,284 @@ +# TODO design to use these attributes across different clients +# `family_id` is specific to bigtable + +from typing import NamedTuple + +from .utils import serializers +from .utils import basetypes + + +class _AttributeType(NamedTuple): + key: bytes + family_id: str + serializer: serializers._Serializer + + +class _Attribute(_AttributeType): + __slots__ = () + _attributes = {} + + def __init__(self, **kwargs): + super().__init__() + _Attribute._attributes[(kwargs["family_id"], kwargs["key"])] = self + + def serialize(self, obj): + return self.serializer.serialize(obj) + + def deserialize(self, stream): + return self.serializer.deserialize(stream) + + @property + def basetype(self): + return self.serializer.basetype + + @property + def index(self): + return int(self.key.decode("utf-8").split("_")[-1]) + + +class _AttributeArray: + _attributearrays = {} + + def __init__(self, pattern, family_id, serializer): + self._pattern = pattern + self._family_id = family_id + self._serializer = serializer + _AttributeArray._attributearrays[(family_id, pattern)] = self + + # TODO: Add missing check in `fromkey(family_id, key)` and remove this + # loop (pre-creates `_Attributes`, so that the inverse lookup works) + for i in range(20): + self[i] # pylint: disable=W0104 + + def __getitem__(self, item): + return _Attribute( + key=self.pattern % item, + family_id=self._family_id, + serializer=self._serializer, + ) + + @property + def pattern(self): + return self._pattern + + @property + def serialize(self): + return self._serializer.serialize + + @property + def deserialize(self): + return self._serializer.deserialize + + @property + def basetype(self): + return self._serializer.basetype + + +class Concurrency: + Counter = _Attribute( + key=b"counter", + family_id="1", + serializer=serializers.NumPyValue(dtype=basetypes.COUNTER), + ) + + Lock = _Attribute(key=b"lock", family_id="0", serializer=serializers.UInt64String()) + + IndefiniteLock = _Attribute( + key=b"indefinite_lock", family_id="0", serializer=serializers.UInt64String() + ) + + +class Connectivity: + Affinity = _Attribute( + key=b"affinities", + family_id="0", + serializer=serializers.NumPyArray(dtype=basetypes.EDGE_AFFINITY), + ) + + Area = _Attribute( + key=b"areas", + family_id="0", + serializer=serializers.NumPyArray(dtype=basetypes.EDGE_AREA), + ) + + CrossChunkEdge = _AttributeArray( + pattern=b"atomic_cross_edges_%d", + family_id="3", + serializer=serializers.NumPyArray( + dtype=basetypes.NODE_ID, shape=(-1, 2), compression_level=22 + ), + ) + + FakeEdges = _Attribute( + key=b"fake_edges", + family_id="3", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)), + ) + + +class Hierarchy: + Child = _Attribute( + key=b"children", + family_id="0", + serializer=serializers.NumPyArray( + dtype=basetypes.NODE_ID, 
compression_level=22 + ), + ) + + FormerParent = _Attribute( + key=b"former_parents", + family_id="0", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID), + ) + + NewParent = _Attribute( + key=b"new_parents", + family_id="0", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID), + ) + + Parent = _Attribute( + key=b"parents", + family_id="0", + serializer=serializers.NumPyValue(dtype=basetypes.NODE_ID), + ) + + +class GraphMeta: + key = b"meta" + Meta = _Attribute(key=key, family_id="0", serializer=serializers.Pickle()) + + +class GraphVersion: + key = b"version" + Version = _Attribute(key=key, family_id="0", serializer=serializers.String("utf-8")) + + +class OperationLogs: + key = b"ioperations" + + from enum import Enum + + class StatusCodes(Enum): + SUCCESS = 0 # all is well, new changes persisted + CREATED = 1 # log record created in storage + EXCEPTION = 2 # edit unsuccessful, unknown error + WRITE_STARTED = 3 # edit successful, start persisting changes + WRITE_FAILED = 4 # edit successful, but changes not persisted + + OperationID = _Attribute( + key=b"operation_id", family_id="0", serializer=serializers.UInt64String() + ) + + UndoOperationID = _Attribute( + key=b"undo_operation_id", family_id="2", serializer=serializers.UInt64String() + ) + + RedoOperationID = _Attribute( + key=b"redo_operation_id", family_id="2", serializer=serializers.UInt64String() + ) + + UserID = _Attribute( + key=b"user", family_id="2", serializer=serializers.String("utf-8") + ) + + RootID = _Attribute( + key=b"roots", + family_id="2", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID), + ) + + SourceID = _Attribute( + key=b"source_ids", + family_id="2", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID), + ) + + SinkID = _Attribute( + key=b"sink_ids", + family_id="2", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID), + ) + + SourceCoordinate = _Attribute( + key=b"source_coords", + family_id="2", + serializer=serializers.NumPyArray(dtype=basetypes.COORDINATES, shape=(-1, 3)), + ) + + SinkCoordinate = _Attribute( + key=b"sink_coords", + family_id="2", + serializer=serializers.NumPyArray(dtype=basetypes.COORDINATES, shape=(-1, 3)), + ) + + BoundingBoxOffset = _Attribute( + key=b"bb_offset", + family_id="2", + serializer=serializers.NumPyArray(dtype=basetypes.COORDINATES), + ) + + AddedEdge = _Attribute( + key=b"added_edges", + family_id="2", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)), + ) + + RemovedEdge = _Attribute( + key=b"removed_edges", + family_id="2", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)), + ) + + Affinity = _Attribute( + key=b"affinities", + family_id="2", + serializer=serializers.NumPyArray(dtype=basetypes.EDGE_AFFINITY), + ) + + Status = _Attribute( + key=b"operation_status", family_id="0", serializer=serializers.Pickle() + ) + + OperationException = _Attribute( + key=b"operation_exception", + family_id="0", + serializer=serializers.String("utf-8"), + ) + + # timestamp at which the new IDs were created during the operation + # this is needed because the timestamp of the operation log + # will change with change in status + OperationTimeStamp = _Attribute( + key=b"operation_ts", family_id="0", serializer=serializers.Pickle() + ) + + @staticmethod + def all(): + return [ + OperationLogs.OperationID, + OperationLogs.UndoOperationID, + OperationLogs.RedoOperationID, + OperationLogs.UserID, + OperationLogs.RootID, + OperationLogs.SourceID, + OperationLogs.SinkID, + 
OperationLogs.SourceCoordinate, + OperationLogs.SinkCoordinate, + OperationLogs.BoundingBoxOffset, + OperationLogs.AddedEdge, + OperationLogs.RemovedEdge, + OperationLogs.Affinity, + OperationLogs.Status, + OperationLogs.OperationException, + OperationLogs.OperationTimeStamp, + ] + + +def from_key(family_id: str, key: bytes): + try: + return _Attribute._attributes[(family_id, key)] + except KeyError: + # FIXME: Look if the key matches a columnarray pattern and + # remove loop initialization in _AttributeArray.__init__() + raise KeyError(f"Unknown key {family_id}:{key.decode()}") diff --git a/pychunkedgraph/graph/cache.py b/pychunkedgraph/graph/cache.py new file mode 100644 index 000000000..f60b6ca92 --- /dev/null +++ b/pychunkedgraph/graph/cache.py @@ -0,0 +1,124 @@ +""" +Cache nodes, parents, children and cross edges. +""" +from sys import maxsize +from datetime import datetime + +from cachetools import cached +from cachetools import LRUCache + +import numpy as np + + +from .utils.basetypes import NODE_ID + + +def update(cache, keys, vals): + try: + # 1 to 1 + for k, v in zip(keys, vals): + cache[k] = v + except TypeError: + # many to 1 + for k in keys: + cache[k] = vals + + +class CacheService: + def __init__(self, cg): + self._cg = cg + + self._parent_vec = np.vectorize(self.parent, otypes=[np.uint64]) + self._children_vec = np.vectorize(self.children, otypes=[np.ndarray]) + self._atomic_cross_edges_vec = np.vectorize( + self.atomic_cross_edges, otypes=[dict] + ) + + # no limit because we don't want to lose new IDs + self.parents_cache = LRUCache(maxsize=maxsize) + self.children_cache = LRUCache(maxsize=maxsize) + self.atomic_cx_edges_cache = LRUCache(maxsize=maxsize) + + def __len__(self): + return ( + len(self.parents_cache) + + len(self.children_cache) + + len(self.atomic_cx_edges_cache) + ) + + def clear(self): + self.parents_cache.clear() + self.children_cache.clear() + self.atomic_cx_edges_cache.clear() + + def parent(self, node_id: np.uint64, *, time_stamp: datetime = None): + @cached(cache=self.parents_cache, key=lambda node_id: node_id) + def parent_decorated(node_id): + return self._cg.get_parent(node_id, raw_only=True, time_stamp=time_stamp) + + return parent_decorated(node_id) + + def children(self, node_id): + @cached(cache=self.children_cache, key=lambda node_id: node_id) + def children_decorated(node_id): + children = self._cg.get_children(node_id, raw_only=True) + update(self.parents_cache, children, node_id) + return children + + return children_decorated(node_id) + + def atomic_cross_edges(self, node_id): + @cached(cache=self.atomic_cx_edges_cache, key=lambda node_id: node_id) + def atomic_cross_edges_decorated(node_id): + edges = self._cg.get_atomic_cross_edges( + np.array([node_id], dtype=NODE_ID), raw_only=True + ) + return edges[node_id] + + return atomic_cross_edges_decorated(node_id) + + def parents_multiple(self, node_ids: np.ndarray, *, time_stamp: datetime = None): + if not node_ids.size: + return node_ids + mask = np.in1d(node_ids, np.fromiter(self.parents_cache.keys(), dtype=NODE_ID)) + parents = node_ids.copy() + parents[mask] = self._parent_vec(node_ids[mask]) + parents[~mask] = self._cg.get_parents( + node_ids[~mask], raw_only=True, time_stamp=time_stamp + ) + update(self.parents_cache, node_ids[~mask], parents[~mask]) + return parents + + def children_multiple(self, node_ids: np.ndarray, *, flatten=False): + result = {} + if not node_ids.size: + return result + mask = np.in1d(node_ids, np.fromiter(self.children_cache.keys(), dtype=NODE_ID)) + 
cached_children_ = self._children_vec(node_ids[mask]) + result.update({id_: c_ for id_, c_ in zip(node_ids[mask], cached_children_)}) + result.update(self._cg.get_children(node_ids[~mask], raw_only=True)) + update( + self.children_cache, node_ids[~mask], [result[k] for k in node_ids[~mask]] + ) + if flatten: + return np.concatenate([*result.values()]) + return result + + def atomic_cross_edges_multiple(self, node_ids: np.ndarray): + result = {} + if not node_ids.size: + return result + mask = np.in1d( + node_ids, np.fromiter(self.atomic_cx_edges_cache.keys(), dtype=NODE_ID) + ) + cached_edges_ = self._atomic_cross_edges_vec(node_ids[mask]) + result.update( + {id_: edges_ for id_, edges_ in zip(node_ids[mask], cached_edges_)} + ) + result.update(self._cg.get_atomic_cross_edges(node_ids[~mask], raw_only=True)) + update( + self.atomic_cx_edges_cache, + node_ids[~mask], + [result[k] for k in node_ids[~mask]], + ) + return result diff --git a/pychunkedgraph/graph/chunkedgraph.py b/pychunkedgraph/graph/chunkedgraph.py new file mode 100644 index 000000000..210bff50b --- /dev/null +++ b/pychunkedgraph/graph/chunkedgraph.py @@ -0,0 +1,1022 @@ +# pylint: disable=invalid-name, missing-docstring, too-many-lines, import-outside-toplevel + +import time +import typing +import datetime + +import numpy as np +from pychunkedgraph import __version__ + +from . import types +from . import operation +from . import attributes +from . import exceptions +from .client import base +from .client import BigTableClient +from .client import BackendClientInfo +from .client import get_default_client_info +from .cache import CacheService +from .meta import ChunkedGraphMeta +from .utils import basetypes +from .utils import id_helpers +from .utils import generic as misc_utils +from .edges import Edges +from .edges import utils as edge_utils +from .chunks import utils as chunk_utils +from .chunks import hierarchy as chunk_hierarchy + + +class ChunkedGraph: + def __init__( + self, + *, + graph_id: str = None, + meta: ChunkedGraphMeta = None, + client_info: BackendClientInfo = get_default_client_info(), + ): + """ + 1. New graph + Requires `meta`; if `client_info` is not passed the default client is used. + After creating `ChunkedGraph` instance, run instance.create(). + 2. Existing graph in default client + Requires `graph_id`. + 3. Existing graphs in other projects/clients, + Requires `graph_id` and `client_info`. 
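As a sketch, the three construction modes look roughly like this; graph IDs are placeholders and building a `ChunkedGraphMeta` is not shown here.

from pychunkedgraph.graph import ChunkedGraph, ChunkedGraphMeta

# 2. open an existing graph with the default client
cg = ChunkedGraph(graph_id="my_graph")

# 1. create a new graph from a prepared meta object
# meta = ChunkedGraphMeta(...)        # graph_config, data_source, etc.
# new_cg = ChunkedGraph(meta=meta)
# new_cg.create()

# 3. open an existing graph in another project by passing explicit client info
# other_cg = ChunkedGraph(graph_id="my_graph", client_info=some_backend_client_info)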
+ """ + # create client based on type + # for now, just use BigTableClient + + if meta: + graph_id = meta.graph_config.ID_PREFIX + meta.graph_config.ID + bt_client = BigTableClient( + graph_id, config=client_info.CONFIG, graph_meta=meta + ) + self._meta = meta + else: + bt_client = BigTableClient(graph_id, config=client_info.CONFIG) + self._meta = bt_client.read_graph_meta() + + self._client = bt_client + self._id_client = bt_client + self._cache_service = None + self.mock_edges = None # hack for unit tests + + @property + def meta(self) -> ChunkedGraphMeta: + return self._meta + + @property + def graph_id(self) -> str: + return self.meta.graph_config.ID_PREFIX + self.meta.graph_config.ID + + @property + def version(self) -> str: + return self.client.read_graph_version() + + @property + def client(self) -> base.SimpleClient: + return self._client + + @property + def id_client(self) -> base.ClientWithIDGen: + return self._id_client + + @property + def cache(self): + return self._cache_service + + @property + def segmentation_resolution(self) -> np.ndarray: + return np.array(self.meta.ws_cv.scale["resolution"]) + + @cache.setter + def cache(self, cache_service: CacheService): + self._cache_service = cache_service + + def create(self): + """Creates the graph in storage client and stores meta.""" + self._client.create_graph(self._meta, version=__version__) + + def update_meta(self, meta: ChunkedGraphMeta, overwrite: bool): + """Update meta of an already existing graph.""" + self.client.update_graph_meta(meta, overwrite=overwrite) + + def range_read_chunk( + self, + chunk_id: basetypes.CHUNK_ID, + properties: typing.Optional[ + typing.Union[typing.Iterable[attributes._Attribute], attributes._Attribute] + ] = None, + time_stamp: typing.Optional[datetime.datetime] = None, + ) -> typing.Dict: + """Read all nodes in a chunk.""" + layer = self.get_chunk_layer(chunk_id) + root_chunk = layer == self.meta.layer_count + max_node_id = self.id_client.get_max_node_id(chunk_id=chunk_id, root_chunk=root_chunk) + if layer == 1: + max_node_id = chunk_id | self.get_segment_id_limit(chunk_id) # pylint: disable=unsupported-binary-operation + + return self.client.read_nodes( + start_id=self.get_node_id(np.uint64(0), chunk_id=chunk_id), + end_id=max_node_id, + end_id_inclusive=True, + properties=properties, + end_time=time_stamp, + end_time_inclusive=True, + ) + + def get_atomic_id_from_coord( + self, + x: int, + y: int, + z: int, + parent_id: np.uint64, + n_tries: int = 5, + time_stamp: typing.Optional[datetime.datetime] = None, + ) -> np.uint64: + """Determines atomic id given a coordinate.""" + if self.get_chunk_layer(parent_id) == 1: + return parent_id + return id_helpers.get_atomic_id_from_coord( + self.meta, + self.get_root, + x, + y, + z, + parent_id, + n_tries=n_tries, + time_stamp=time_stamp, + ) + + def get_atomic_ids_from_coords( + self, + coordinates: typing.Sequence[typing.Sequence[int]], + parent_id: np.uint64, + max_dist_nm: int = 150, + ) -> typing.Sequence[np.uint64]: + """Retrieves supervoxel ids for multiple coords. 
+ + :param coordinates: n x 3 np.ndarray of locations in voxel space + :param parent_id: parent id common to all coordinates at any layer + :param max_dist_nm: max distance explored + :return: supervoxel ids; returns None if no solution was found + """ + if self.get_chunk_layer(parent_id) == 1: + return np.array([parent_id] * len(coordinates), dtype=np.uint64) + + # Enable search with old parent by using its timestamp and map to parents + parent_ts = self.get_node_timestamps([parent_id], return_numpy=False)[0] + return id_helpers.get_atomic_ids_from_coords( + self.meta, + coordinates, + parent_id, + self.get_chunk_layer(parent_id), + parent_ts, + self.get_roots, + max_dist_nm, + ) + + def get_parents( + self, + node_ids: typing.Sequence[np.uint64], + *, + raw_only=False, + current: bool = True, + fail_to_zero: bool = False, + time_stamp: typing.Optional[datetime.datetime] = None, + ): + """ + If current=True returns only the latest parents. + Else all parents along with timestamps. + """ + if raw_only or not self.cache: + time_stamp = misc_utils.get_valid_timestamp(time_stamp) + parent_rows = self.client.read_nodes( + node_ids=node_ids, + properties=attributes.Hierarchy.Parent, + end_time=time_stamp, + end_time_inclusive=True, + ) + if not parent_rows: + return types.empty_1d + + parents = [] + if current: + for id_ in node_ids: + try: + parents.append(parent_rows[id_][0].value) + except KeyError as exc: + if fail_to_zero: + parents.append(0) + else: + raise KeyError from exc + parents = np.array(parents, dtype=basetypes.NODE_ID) + else: + for id_ in node_ids: + try: + parents.append( + [(p.value, p.timestamp) for p in parent_rows[id_]] + ) + except KeyError as exc: + if fail_to_zero: + parents.append([(0, datetime.datetime.fromtimestamp(0))]) + else: + raise KeyError from exc + return parents + return self.cache.parents_multiple(node_ids, time_stamp=time_stamp) + + def get_parent( + self, + node_id: np.uint64, + *, + raw_only=False, + latest: bool = True, + time_stamp: typing.Optional[datetime.datetime] = None, + ) -> typing.Union[typing.List[typing.Tuple], np.uint64]: + if raw_only or not self.cache: + time_stamp = misc_utils.get_valid_timestamp(time_stamp) + parents = self.client.read_node( + node_id, + properties=attributes.Hierarchy.Parent, + end_time=time_stamp, + end_time_inclusive=True, + ) + if not parents: + return None + if latest: + return parents[0].value + return [(p.value, p.timestamp) for p in parents] + return self.cache.parent(node_id, time_stamp=time_stamp) + + def get_children( + self, + node_id_or_ids: typing.Union[typing.Iterable[np.uint64], np.uint64], + *, + raw_only=False, + flatten: bool = False, + ) -> typing.Union[typing.Dict, np.ndarray]: + """ + Children for the specified NodeID or NodeIDs. + If flatten == True, an array is returned, else a dict {node_id: children}. 
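Typical calls, assuming `cg` is an open ChunkedGraph and `node_ids` is a `np.uint64` array of existing IDs:

parents = cg.get_parents(node_ids)                       # latest parent of each ID
history = cg.get_parents(node_ids, current=False)        # [(parent, timestamp), ...] per ID
children_d = cg.get_children(node_ids)                   # {node_id: array of children}
all_children = cg.get_children(node_ids, flatten=True)   # single concatenated array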
+ """ + if np.isscalar(node_id_or_ids): + if raw_only or not self.cache: + children = self.client.read_node( + node_id=node_id_or_ids, properties=attributes.Hierarchy.Child + ) + if not children: + return types.empty_1d.copy() + return children[0].value + return self.cache.children(node_id_or_ids) + node_children_d = self._get_children_multiple(node_id_or_ids, raw_only=raw_only) + if flatten: + if not node_children_d: + return types.empty_1d.copy() + return np.concatenate(list(node_children_d.values())) + return node_children_d + + def _get_children_multiple( + self, node_ids: typing.Iterable[np.uint64], *, raw_only=False + ) -> typing.Dict: + if raw_only or not self.cache: + node_children_d = self.client.read_nodes( + node_ids=node_ids, properties=attributes.Hierarchy.Child + ) + return { + x: node_children_d[x][0].value + if x in node_children_d + else types.empty_1d.copy() + for x in node_ids + } + return self.cache.children_multiple(node_ids) + + def get_atomic_cross_edges( + self, l2_ids: typing.Iterable, *, raw_only=False + ) -> typing.Dict[np.uint64, typing.Dict[int, typing.Iterable]]: + """Returns cross edges for level 2 IDs.""" + if raw_only or not self.cache: + node_edges_d_d = self.client.read_nodes( + node_ids=l2_ids, + properties=[ + attributes.Connectivity.CrossChunkEdge[l] + for l in range(2, max(3, self.meta.layer_count)) + ], + ) + result = {} + for id_ in l2_ids: + try: + result[id_] = { + prop.index: val[0].value.copy() + for prop, val in node_edges_d_d[id_].items() + } + except KeyError: + result[id_] = {} + return result + return self.cache.atomic_cross_edges_multiple(l2_ids) + + def get_cross_chunk_edges( + self, node_ids: typing.Iterable, uplift=True, all_layers=False + ) -> typing.Dict[np.uint64, typing.Dict[int, typing.Iterable]]: + """ + Cross chunk edges for `node_id` at `node_layer`. + The edges are between node IDs at the `node_layer`, not atomic cross edges. + Returns dict {layer_id: cross_edges} + The first layer (>= `node_layer`) with atleast one cross chunk edge. + For current use-cases, other layers are not relevant. + + For performance, only children that lie along chunk boundary are considered. + Cross edges that belong to inner level 2 IDs are subsumed within the chunk. + This is because cross edges are stored only in level 2 IDs. + """ + result = {} + node_ids = np.array(node_ids, dtype=basetypes.NODE_ID) + if not node_ids.size: + return result + + node_l2ids_d = {} + layers_ = self.get_chunk_layers(node_ids) + for l in set(layers_): + node_l2ids_d.update(self._get_bounding_l2_children(node_ids[layers_ == l])) + l2_edges_d_d = self.get_atomic_cross_edges( + np.concatenate(list(node_l2ids_d.values())) + ) + for node_id in node_ids: + l2_edges_ds = [l2_edges_d_d[l2_id] for l2_id in node_l2ids_d[node_id]] + if all_layers: + result[node_id] = edge_utils.concatenate_cross_edge_dicts(l2_edges_ds) + else: + result[node_id] = self._get_min_layer_cross_edges( + node_id, l2_edges_ds, uplift=uplift + ) + return result + + def _get_min_layer_cross_edges( + self, + node_id: basetypes.NODE_ID, + l2id_atomic_cross_edges_ds: typing.Iterable, + uplift=True, + ) -> typing.Dict[int, typing.Iterable]: + """ + Find edges at relevant min_layer >= node_layer. + `l2id_atomic_cross_edges_ds` is a list of atomic cross edges of + level 2 IDs that are descendants of `node_id`. 
+ """ + min_layer, edges = edge_utils.filter_min_layer_cross_edges_multiple( + self.meta, l2id_atomic_cross_edges_ds, self.get_chunk_layer(node_id) + ) + if self.get_chunk_layer(node_id) < min_layer: + # cross edges irrelevant + return {self.get_chunk_layer(node_id): types.empty_2d} + if not uplift: + return {min_layer: edges} + node_root_id = node_id + node_root_id = self.get_root(node_id, stop_layer=min_layer, ceil=False) + edges[:, 0] = node_root_id + edges[:, 1] = self.get_roots(edges[:, 1], stop_layer=min_layer, ceil=False) + return {min_layer: np.unique(edges, axis=0) if edges.size else types.empty_2d} + + def get_roots( + self, + node_ids: typing.Sequence[np.uint64], + *, + assert_roots: bool = False, + time_stamp: typing.Optional[datetime.datetime] = None, + stop_layer: int = None, + ceil: bool = True, + fail_to_zero: bool = False, + n_tries: int = 1, + ) -> typing.Union[np.ndarray, typing.Dict[int, np.ndarray]]: + """ + Returns node IDs at the root_layer/ <= stop_layer. + Use `assert_roots=True` to ensure returned IDs are at root level. + When `assert_roots=False`, returns highest available IDs and + cases where there are no root IDs are silently ignored. + """ + time_stamp = misc_utils.get_valid_timestamp(time_stamp) + stop_layer = self.meta.layer_count if not stop_layer else stop_layer + assert stop_layer <= self.meta.layer_count + layer_mask = np.ones(len(node_ids), dtype=bool) + + for _ in range(n_tries): + chunk_layers = self.get_chunk_layers(node_ids) + layer_mask[chunk_layers >= stop_layer] = False + layer_mask[node_ids == 0] = False + + parent_ids = np.array(node_ids, dtype=basetypes.NODE_ID) + for _ in range(int(stop_layer + 1)): + filtered_ids = parent_ids[layer_mask] + unique_ids, inverse = np.unique(filtered_ids, return_inverse=True) + temp_ids = self.get_parents( + unique_ids, time_stamp=time_stamp, fail_to_zero=fail_to_zero + ) + if not temp_ids.size: + break + else: + temp_ids_i = temp_ids[inverse] + new_layer_mask = layer_mask.copy() + new_layer_mask[new_layer_mask] = ( + self.get_chunk_layers(temp_ids_i) < stop_layer + ) + if not ceil: + rev_m = self.get_chunk_layers(temp_ids_i) > stop_layer + temp_ids_i[rev_m] = filtered_ids[rev_m] + + parent_ids[layer_mask] = temp_ids_i + layer_mask = new_layer_mask + + if np.all(~layer_mask): + if assert_roots: + assert not np.any( + self.get_chunk_layers(parent_ids) + < self.meta.layer_count + ), "roots not found for some IDs" + return parent_ids + + if not ceil and np.all( + self.get_chunk_layers(parent_ids[parent_ids != 0]) >= stop_layer + ): + if assert_roots: + assert not np.any( + self.get_chunk_layers(parent_ids) < self.meta.layer_count + ), "roots not found for some IDs" + return parent_ids + elif ceil: + if assert_roots: + assert not np.any( + self.get_chunk_layers(parent_ids) < self.meta.layer_count + ), "roots not found for some IDs" + return parent_ids + else: + time.sleep(0.5) + if assert_roots: + assert not np.any( + self.get_chunk_layers(parent_ids) < self.meta.layer_count + ), "roots not found for some IDs" + return parent_ids + + def get_root( + self, + node_id: np.uint64, + *, + time_stamp: typing.Optional[datetime.datetime] = None, + get_all_parents: bool = False, + stop_layer: int = None, + ceil: bool = True, + n_tries: int = 1, + ) -> typing.Union[typing.List[np.uint64], np.uint64]: + """Takes a node id and returns the associated agglomeration ids.""" + time_stamp = misc_utils.get_valid_timestamp(time_stamp) + parent_id = node_id + all_parent_ids = [] + stop_layer = self.meta.layer_count if not stop_layer 
else stop_layer + if self.get_chunk_layer(parent_id) == stop_layer: + return ( + np.array([node_id], dtype=basetypes.NODE_ID) + if get_all_parents + else node_id + ) + + for _ in range(n_tries): + parent_id = node_id + for _ in range(self.get_chunk_layer(node_id), int(stop_layer + 1)): + temp_parent_id = self.get_parent(parent_id, time_stamp=time_stamp) + if temp_parent_id is None: + break + else: + parent_id = temp_parent_id + + if self.get_chunk_layer(parent_id) >= stop_layer: + if self.get_chunk_layer(parent_id) == stop_layer: + all_parent_ids.append(parent_id) + elif ceil: + all_parent_ids.append(parent_id) + break + else: + all_parent_ids.append(parent_id) + + if self.get_chunk_layer(parent_id) >= stop_layer: + break + else: + time.sleep(0.5) + + if self.get_chunk_layer(parent_id) < stop_layer: + raise exceptions.ChunkedGraphError( + f"Cannot find root id {node_id}, {stop_layer}, {time_stamp}" + ) + + if get_all_parents: + return np.array(all_parent_ids, dtype=basetypes.NODE_ID) + else: + if len(all_parent_ids) == 0: + return node_id + else: + return all_parent_ids[-1] + + def is_latest_roots( + self, + root_ids: typing.Iterable, + time_stamp: typing.Optional[datetime.datetime] = None, + ) -> typing.Iterable: + """Determines whether root ids are superseded.""" + time_stamp = misc_utils.get_valid_timestamp(time_stamp) + + row_dict = self.client.read_nodes( + node_ids=root_ids, + properties=[attributes.Hierarchy.Child, attributes.Hierarchy.NewParent], + end_time=time_stamp, + ) + + if len(row_dict) == 0: + return np.zeros(len(root_ids), dtype=bool) + + latest_roots = [ + k for k, v in row_dict.items() if not attributes.Hierarchy.NewParent in v + ] + return np.isin(root_ids, latest_roots) + + def get_all_parents_dict( + self, + node_id: basetypes.NODE_ID, + *, + time_stamp: typing.Optional[datetime.datetime] = None, + ) -> typing.Dict: + """Takes a node id and returns all parents up to root.""" + parent_ids = self.get_root( + node_id=node_id, time_stamp=time_stamp, get_all_parents=True + ) + return dict(zip(self.get_chunk_layers(parent_ids), parent_ids)) + + def get_subgraph( + self, + node_id_or_ids: typing.Union[np.uint64, typing.Iterable], + bbox: typing.Optional[typing.Sequence[typing.Sequence[int]]] = None, + bbox_is_coordinate: bool = False, + return_layers: typing.List = [2], + nodes_only: bool = False, + edges_only: bool = False, + leaves_only: bool = False, + return_flattened: bool = False, + ) -> typing.Tuple[typing.Dict, typing.Dict, Edges]: + """ + Generic subgraph method. + """ + from .subgraph import get_subgraph_nodes + from .subgraph import get_subgraph_edges_and_leaves + + if nodes_only: + return get_subgraph_nodes( + self, + node_id_or_ids, + bbox, + bbox_is_coordinate, + return_layers, + return_flattened=return_flattened, + ) + return get_subgraph_edges_and_leaves( + self, node_id_or_ids, bbox, bbox_is_coordinate, edges_only, leaves_only + ) + + def get_subgraph_nodes( + self, + node_id_or_ids: typing.Union[np.uint64, typing.Iterable], + bbox: typing.Optional[typing.Sequence[typing.Sequence[int]]] = None, + bbox_is_coordinate: bool = False, + return_layers: typing.List = [2], + serializable: bool = False, + return_flattened: bool = False, + ) -> typing.Tuple[typing.Dict, typing.Dict, Edges]: + """ + Get the children of `node_ids` that are at each of + return_layers within the specified bounding box. 
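Typical lineage queries with the methods above, assuming `cg` is an open ChunkedGraph and `node_id`/`node_ids` exist in it:

root = cg.get_root(node_id)                            # ID at the top (root) layer
lineage = cg.get_root(node_id, get_all_parents=True)   # parents at every layer up to the root
l4_parent = cg.get_root(node_id, stop_layer=4)         # stop the upward walk early, at layer 4
roots = cg.get_roots(node_ids, assert_roots=True)      # vectorized; asserts all IDs reach the root layer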
+ """ + from .subgraph import get_subgraph_nodes + + return get_subgraph_nodes( + self, + node_id_or_ids, + bbox, + bbox_is_coordinate, + return_layers, + serializable=serializable, + return_flattened=return_flattened, + ) + + def get_subgraph_edges( + self, + node_id_or_ids: typing.Union[np.uint64, typing.Iterable], + bbox: typing.Optional[typing.Sequence[typing.Sequence[int]]] = None, + bbox_is_coordinate: bool = False, + ): + """ + Get the atomic edges of the `node_ids` within the specified bounding box. + """ + from .subgraph import get_subgraph_edges_and_leaves + + return get_subgraph_edges_and_leaves( + self, node_id_or_ids, bbox, bbox_is_coordinate, True, False + ) + + def get_subgraph_leaves( + self, + node_id_or_ids: typing.Union[np.uint64, typing.Iterable], + bbox: typing.Optional[typing.Sequence[typing.Sequence[int]]] = None, + bbox_is_coordinate: bool = False, + ): + """ + Get the supervoxels of the `node_ids` within the specified bounding box. + """ + from .subgraph import get_subgraph_edges_and_leaves + + return get_subgraph_edges_and_leaves( + self, node_id_or_ids, bbox, bbox_is_coordinate, False, True + ) + + def get_fake_edges( + self, chunk_ids: np.ndarray, time_stamp: datetime.datetime = None + ) -> typing.Dict: + result = {} + fake_edges_d = self.client.read_nodes( + node_ids=chunk_ids, + properties=attributes.Connectivity.FakeEdges, + end_time=time_stamp, + end_time_inclusive=True, + fake_edges=True, + ) + for id_, val in fake_edges_d.items(): + edges = np.concatenate( + [np.array(e.value, dtype=basetypes.NODE_ID) for e in val] + ) + result[id_] = Edges(edges[:, 0], edges[:, 1], fake_edges=True) + return result + + def get_l2_agglomerations( + self, level2_ids: np.ndarray, edges_only: bool = False + ) -> typing.Tuple[typing.Dict[int, types.Agglomeration], np.ndarray]: + """ + Children of Level 2 Node IDs and edges. + Edges are read from cloud storage. 
+ """ + from itertools import chain + from functools import reduce + from .misc import get_agglomerations + + chunk_ids = np.unique(self.get_chunk_ids_from_node_ids(level2_ids)) + # google does not provide a storage emulator at the moment + # this is an ugly hack to avoid permission issues in tests + # find a better way to test + edges_d = {} + if self.mock_edges is None: + edges_d = self.read_chunk_edges(chunk_ids) + + fake_edges = self.get_fake_edges(chunk_ids) + all_chunk_edges = reduce( + lambda x, y: x + y, + chain(edges_d.values(), fake_edges.values()), + Edges([], []), + ) + + if edges_only: + if self.mock_edges is not None: + all_chunk_edges = self.mock_edges.get_pairs() + else: + all_chunk_edges = all_chunk_edges.get_pairs() + supervoxels = self.get_children(level2_ids, flatten=True) + mask0 = np.in1d(all_chunk_edges[:, 0], supervoxels) + mask1 = np.in1d(all_chunk_edges[:, 1], supervoxels) + return all_chunk_edges[mask0 & mask1] + + l2id_children_d = self.get_children(level2_ids) + sv_parent_d = {} + for l2id in l2id_children_d: + svs = l2id_children_d[l2id] + sv_parent_d.update(dict(zip(svs.tolist(), [l2id] * len(svs)))) + + in_edges, out_edges, cross_edges = edge_utils.categorize_edges_v2( + self.meta, + all_chunk_edges, + sv_parent_d + ) + + agglomeration_d = get_agglomerations( + l2id_children_d, in_edges, out_edges, cross_edges, sv_parent_d + ) + return ( + agglomeration_d, + (self.mock_edges,) + if self.mock_edges is not None + else (in_edges, out_edges, cross_edges), + ) + + def get_node_timestamps( + self, node_ids: typing.Sequence[np.uint64], return_numpy=True + ) -> typing.Iterable: + """ + The timestamp of the children column can be assumed + to be the timestamp at which the node ID was created. + """ + children = self.client.read_nodes( + node_ids=node_ids, properties=attributes.Hierarchy.Child + ) + + if not children: + if return_numpy: + return np.array([], dtype=np.datetime64) + return [] + if return_numpy: + return np.array( + [children[x][0].timestamp for x in node_ids], dtype=np.datetime64 + ) + return [children[x][0].timestamp for x in node_ids] + + # OPERATIONS + def add_edges( + self, + user_id: str, + atomic_edges: typing.Sequence[np.uint64], + *, + affinities: typing.Sequence[np.float32] = None, + source_coords: typing.Sequence[int] = None, + sink_coords: typing.Sequence[int] = None, + allow_same_segment_merge: typing.Optional[bool] = False, + ) -> operation.GraphEditOperation.Result: + """ + Adds an edge to the chunkedgraph + Multi-user safe through locking of the root node + This function acquires a lock and ensures that it still owns the + lock before executing the write. 
+ :return: GraphEditOperation.Result + """ + return operation.MergeOperation( + self, + user_id=user_id, + added_edges=atomic_edges, + affinities=affinities, + source_coords=source_coords, + sink_coords=sink_coords, + allow_same_segment_merge=allow_same_segment_merge, + ).execute() + + def remove_edges( + self, + user_id: str, + *, + atomic_edges: typing.Sequence[typing.Tuple[np.uint64, np.uint64]] = None, + source_ids: typing.Sequence[np.uint64] = None, + sink_ids: typing.Sequence[np.uint64] = None, + source_coords: typing.Sequence[typing.Sequence[int]] = None, + sink_coords: typing.Sequence[typing.Sequence[int]] = None, + mincut: bool = True, + path_augment: bool = True, + disallow_isolating_cut: bool = True, + bb_offset: typing.Tuple[int, int, int] = (240, 240, 24), + ) -> operation.GraphEditOperation.Result: + """ + Removes edges - either directly or after applying a mincut + Multi-user safe through locking of the root node + This function acquires a lock and ensures that it still owns the + lock before executing the write. + :param atomic_edges: list of 2 uint64 + :param bb_offset: list of 3 ints + [x, y, z] bounding box padding beyond box spanned by coordinates + :return: GraphEditOperation.Result + """ + source_ids = [source_ids] if np.isscalar(source_ids) else source_ids + sink_ids = [sink_ids] if np.isscalar(sink_ids) else sink_ids + if mincut: + return operation.MulticutOperation( + self, + user_id=user_id, + source_ids=source_ids, + sink_ids=sink_ids, + source_coords=source_coords, + sink_coords=sink_coords, + bbox_offset=bb_offset, + path_augment=path_augment, + disallow_isolating_cut=disallow_isolating_cut, + ).execute() + + if not atomic_edges: + # Shim - can remove this check once all functions call the split properly/directly + if len(source_ids) != len(sink_ids): + raise exceptions.PreconditionError( + "Split operation require the same number of source and sink IDs" + ) + atomic_edges = np.array( + [source_ids, sink_ids], dtype=basetypes.NODE_ID + ).transpose() + return operation.SplitOperation( + self, + user_id=user_id, + removed_edges=atomic_edges, + source_coords=source_coords, + sink_coords=sink_coords, + ).execute() + + def undo_operation( + self, user_id: str, operation_id: np.uint64 + ) -> operation.GraphEditOperation.Result: + """Applies the inverse of a previous GraphEditOperation + :param user_id: str + :param operation_id: operation_id to be inverted + :return: GraphEditOperation.Result + """ + return operation.GraphEditOperation.undo_operation( + self, + user_id=user_id, + operation_id=operation_id, + multicut_as_split=True, + ).execute() + + def redo_operation( + self, user_id: str, operation_id: np.uint64 + ) -> operation.GraphEditOperation.Result: + """Re-applies a previous GraphEditOperation + :param user_id: str + :param operation_id: operation_id to be repeated + :return: GraphEditOperation.Result + """ + return operation.GraphEditOperation.redo_operation( + self, + user_id=user_id, + operation_id=operation_id, + multicut_as_split=True, + ).execute() + + # PRIVATE + + def _get_bounding_chunk_ids( + self, + parent_chunk_ids: typing.Iterable, + unique: bool = False, + ) -> typing.Dict: + """ + Returns bounding chunk IDs at layers < parent_layer for all chunk IDs. 
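For example, a mincut-based split followed by an undo might look like this; the user ID, supervoxel IDs, coordinates and operation ID are placeholders:

result = cg.remove_edges(
    user_id="user@example.com",
    source_ids=[supervoxel_a],
    sink_ids=[supervoxel_b],
    source_coords=[[1000, 2000, 300]],
    sink_coords=[[1200, 2100, 310]],
    mincut=True,   # compute the cut instead of passing atomic_edges directly
)

# roll the edit back later
cg.undo_operation(user_id="user@example.com", operation_id=some_operation_id)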
+ Dict[parent_chunk_id] = np.array(bounding_chunk_ids) + """ + parent_chunk_coords = self.get_chunk_coordinates_multiple(parent_chunk_ids) + parents_layer = self.get_chunk_layer(parent_chunk_ids[0]) + chunk_id_bchunk_ids_d = {} + for i, chunk_id in enumerate(parent_chunk_ids): + if chunk_id in chunk_id_bchunk_ids_d: + # `parent_chunk_ids` can have duplicates + # avoid redundant calculations + continue + parent_coord = parent_chunk_coords[i] + chunk_ids = [types.empty_1d] + for child_layer in range(2, parents_layer): + bcoords = chunk_utils.get_bounding_children_chunks( + self.meta, + parents_layer, + parent_coord, + child_layer, + return_unique=False, + ) + bchunks_ids = chunk_utils.get_chunk_ids_from_coords( + self.meta, child_layer, bcoords + ) + chunk_ids.append(bchunks_ids) + chunk_ids = np.concatenate(chunk_ids) + if unique: + chunk_ids = np.unique(chunk_ids) + chunk_id_bchunk_ids_d[chunk_id] = chunk_ids + return chunk_id_bchunk_ids_d + + def _get_bounding_l2_children(self, parents: typing.Iterable) -> typing.Dict: + parent_chunk_ids = self.get_chunk_ids_from_node_ids(parents) + chunk_id_bchunk_ids_d = self._get_bounding_chunk_ids( + parent_chunk_ids, unique=len(parents) >= 200 + ) + + parent_descendants_d = { + _id: np.array([_id], dtype=basetypes.NODE_ID) for _id in parents + } + descendants_all = np.concatenate(list(parent_descendants_d.values())) + descendants_layers = self.get_chunk_layers(descendants_all) + layer_mask = descendants_layers > 2 + descendants_all = descendants_all[layer_mask] + + while descendants_all.size: + descendant_children_d = self.get_children(descendants_all) + for i, parent_id in enumerate(parents): + _descendants = parent_descendants_d[parent_id] + _layers = self.get_chunk_layers(_descendants) + _l2mask = _layers == 2 + descendants = [_descendants[_l2mask]] + for child in _descendants[~_l2mask]: + descendants.append(descendant_children_d[child]) + descendants = np.concatenate(descendants) + chunk_ids = self.get_chunk_ids_from_node_ids(descendants) + bchunk_ids = chunk_id_bchunk_ids_d[parent_chunk_ids[i]] + bounding_descendants = descendants[np.in1d(chunk_ids, bchunk_ids)] + parent_descendants_d[parent_id] = bounding_descendants + + descendants_all = np.concatenate(list(parent_descendants_d.values())) + descendants_layers = self.get_chunk_layers(descendants_all) + layer_mask = descendants_layers > 2 + descendants_all = descendants_all[layer_mask] + return parent_descendants_d + + # HELPERS / WRAPPERS + + def is_root(self, node_id: basetypes.NODE_ID) -> bool: + return self.get_chunk_layer(node_id) == self.meta.layer_count + + def get_serialized_info(self): + return { + "graph_id": self.meta.graph_config.ID_PREFIX + self.meta.graph_config.ID + } + + def get_node_id( + self, + segment_id: np.uint64, + chunk_id: typing.Optional[np.uint64] = None, + layer: typing.Optional[int] = None, + x: typing.Optional[int] = None, + y: typing.Optional[int] = None, + z: typing.Optional[int] = None, + ) -> np.uint64: + return id_helpers.get_node_id( + self.meta, segment_id, chunk_id=chunk_id, layer=layer, x=x, y=y, z=z + ) + + def get_segment_id(self, node_id: basetypes.NODE_ID): + return id_helpers.get_segment_id(self.meta, node_id) + + def get_segment_id_limit(self, node_or_chunk_id: basetypes.NODE_ID): + return id_helpers.get_segment_id_limit(self.meta, node_or_chunk_id) + + def get_chunk_layer(self, node_or_chunk_id: basetypes.NODE_ID): + return chunk_utils.get_chunk_layer(self.meta, node_or_chunk_id) + + def get_chunk_layers(self, node_or_chunk_ids: typing.Sequence): + 
return chunk_utils.get_chunk_layers(self.meta, node_or_chunk_ids) + + def get_chunk_coordinates(self, node_or_chunk_id: basetypes.NODE_ID): + return chunk_utils.get_chunk_coordinates(self.meta, node_or_chunk_id) + + def get_chunk_coordinates_multiple(self, node_or_chunk_ids: typing.Sequence): + node_or_chunk_ids = np.array(node_or_chunk_ids, dtype=basetypes.NODE_ID) + layers = self.get_chunk_layers(node_or_chunk_ids) + assert np.all(layers == layers[0]), "All IDs must have the same layer." + return chunk_utils.get_chunk_coordinates_multiple(self.meta, node_or_chunk_ids) + + def get_chunk_id( + self, + node_id: basetypes.NODE_ID = None, + layer: typing.Optional[int] = None, + x: typing.Optional[int] = 0, + y: typing.Optional[int] = 0, + z: typing.Optional[int] = 0, + ): + return chunk_utils.get_chunk_id( + self.meta, node_id=node_id, layer=layer, x=x, y=y, z=z + ) + + def get_chunk_ids_from_node_ids(self, node_ids: typing.Sequence): + return chunk_utils.get_chunk_ids_from_node_ids(self.meta, node_ids) + + def get_children_chunk_ids(self, node_or_chunk_id: basetypes.NODE_ID): + return chunk_hierarchy.get_children_chunk_ids(self.meta, node_or_chunk_id) + + def get_parent_chunk_id( + self, node_or_chunk_id: basetypes.NODE_ID, parent_layer: int = None + ): + if not parent_layer: + parent_layer = self.get_chunk_layer(node_or_chunk_id) + 1 + return chunk_hierarchy.get_parent_chunk_id( + self.meta, node_or_chunk_id, parent_layer + ) + + def get_parent_chunk_ids(self, node_or_chunk_id: basetypes.NODE_ID): + return chunk_hierarchy.get_parent_chunk_ids(self.meta, node_or_chunk_id) + + def get_parent_chunk_id_dict(self, node_or_chunk_id: basetypes.NODE_ID): + return chunk_hierarchy.get_parent_chunk_id_dict(self.meta, node_or_chunk_id) + + def get_cross_chunk_edges_layer(self, cross_edges: typing.Iterable): + return edge_utils.get_cross_chunk_edges_layer(self.meta, cross_edges) + + def read_chunk_edges(self, chunk_ids: typing.Iterable) -> typing.Dict: + from ..io.edges import get_chunk_edges + + return get_chunk_edges( + self.meta.data_source.EDGES, + self.get_chunk_coordinates_multiple(chunk_ids), + ) + + def get_proofread_root_ids( + self, + start_time: typing.Optional[datetime.datetime] = None, + end_time: typing.Optional[datetime.datetime] = None, + ): + from .misc import get_proofread_root_ids + + return get_proofread_root_ids(self, start_time, end_time) + + def get_earliest_timestamp(self): + from datetime import timedelta + + for op_id in range(100): + _, timestamp = self.client.read_log_entry(op_id) + if timestamp is not None: + return timestamp - timedelta(milliseconds=500) diff --git a/pychunkedgraph/creator/__init__.py b/pychunkedgraph/graph/chunks/__init__.py similarity index 100% rename from pychunkedgraph/creator/__init__.py rename to pychunkedgraph/graph/chunks/__init__.py diff --git a/pychunkedgraph/graph/chunks/atomic.py b/pychunkedgraph/graph/chunks/atomic.py new file mode 100644 index 000000000..e3de065ff --- /dev/null +++ b/pychunkedgraph/graph/chunks/atomic.py @@ -0,0 +1,65 @@ +from typing import List +from typing import Sequence +from itertools import product + +import numpy as np + +from .utils import get_bounding_children_chunks +from ..meta import ChunkedGraphMeta +from ..utils.generic import get_valid_timestamp +from ..utils import basetypes + + +def get_touching_atomic_chunks( + chunkedgraph_meta: ChunkedGraphMeta, + layer: int, + chunk_coords: Sequence[int], + include_both=False, +) -> List: + """get atomic chunk coordinates along touching faces of children chunks of a 
parent chunk""" + chunk_coords = np.array(chunk_coords, dtype=int) + touching_atomic_chunks = [] + + # atomic chunk count along one dimension + atomic_chunk_count = chunkedgraph_meta.graph_config.FANOUT ** (layer - 2) + layer2_chunk_bounds = chunkedgraph_meta.layer_chunk_bounds[2] + + chunk_offset = chunk_coords * atomic_chunk_count + mid = (atomic_chunk_count // 2) - 1 + + # TODO (akhileshh) convert this for loop to numpy + # relevant chunks along touching planes at center + for axis_1, axis_2 in product(*[range(atomic_chunk_count)] * 2): + # x-y plane + chunk_1 = chunk_offset + np.array((axis_1, axis_2, mid)) + touching_atomic_chunks.append(chunk_1) + # x-z plane + chunk_1 = chunk_offset + np.array((axis_1, mid, axis_2)) + touching_atomic_chunks.append(chunk_1) + # y-z plane + chunk_1 = chunk_offset + np.array((mid, axis_1, axis_2)) + touching_atomic_chunks.append(chunk_1) + + if include_both: + chunk_2 = chunk_offset + np.array((axis_1, axis_2, mid + 1)) + touching_atomic_chunks.append(chunk_2) + + chunk_2 = chunk_offset + np.array((axis_1, mid + 1, axis_2)) + touching_atomic_chunks.append(chunk_2) + + chunk_2 = chunk_offset + np.array((mid + 1, axis_1, axis_2)) + touching_atomic_chunks.append(chunk_2) + + chunks = np.array(touching_atomic_chunks, dtype=int) + mask = np.all(chunks < layer2_chunk_bounds, axis=1) + result = chunks[mask] + if result.size: + return np.unique(result, axis=0) + return [] + + +def get_bounding_atomic_chunks( + chunkedgraph_meta: ChunkedGraphMeta, layer: int, chunk_coords: Sequence[int] +) -> List: + """Atomic chunk coordinates along the boundary of a chunk""" + return get_bounding_children_chunks(chunkedgraph_meta, layer, chunk_coords, 2) diff --git a/pychunkedgraph/graph/chunks/hierarchy.py b/pychunkedgraph/graph/chunks/hierarchy.py new file mode 100644 index 000000000..32d6029ee --- /dev/null +++ b/pychunkedgraph/graph/chunks/hierarchy.py @@ -0,0 +1,92 @@ +from itertools import product +from typing import Sequence +from typing import Iterable + +import numpy as np + +from . import utils +from ..meta import ChunkedGraphMeta + + +def get_children_chunk_coords( + meta: ChunkedGraphMeta, layer: int, chunk_coords: Sequence[int] +) -> Iterable: + """ + Returns coordinates of children chunks. + Filters out chunks that are outside the boundary of the dataset.
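The child-coordinate expansion performed below can be pictured with a small standalone sketch; the fanout and layer bounds are made-up values rather than ones read from a real graph.

from itertools import product

import numpy as np

FANOUT = 2  # placeholder for meta.graph_config.FANOUT

def children_chunk_coords(chunk_coords, layer_boundaries):
    """Expand a parent chunk coordinate into its (at most FANOUT**3) children,
    dropping children that fall outside the dataset bounds."""
    chunk_coords = np.array(chunk_coords, dtype=int)
    children = []
    for dcoord in product(range(FANOUT), repeat=3):
        child = chunk_coords * FANOUT + np.array(dcoord, dtype=int)
        if np.all(child < layer_boundaries):
            children.append(child)
    return np.array(children)

# Parent chunk (1, 1, 0) over a 3 x 3 x 2 grid of child chunks: only 2 children survive.
print(children_chunk_coords((1, 1, 0), layer_boundaries=np.array([3, 3, 2])))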
+ """ + chunk_coords = np.array(chunk_coords, dtype=int) + children_layer = layer - 1 + layer_boundaries = meta.layer_chunk_bounds[children_layer] + children_coords = [] + + for dcoord in product(*[range(meta.graph_config.FANOUT)] * 3): + dcoord = np.array(dcoord, dtype=int) + child_coords = chunk_coords * meta.graph_config.FANOUT + dcoord + check_bounds = np.less(child_coords, layer_boundaries) + if np.all(check_bounds): + children_coords.append(child_coords) + return np.array(children_coords) + + +def get_children_chunk_ids( + meta: ChunkedGraphMeta, node_or_chunk_id: np.uint64 +) -> np.ndarray: + """Calculates the ids of the children chunks in the next lower layer.""" + x, y, z = utils.get_chunk_coordinates(meta, node_or_chunk_id) + layer = utils.get_chunk_layer(meta, node_or_chunk_id) + + if layer == 1: + return np.array([]) + elif layer == 2: + return np.array([utils.get_chunk_id(meta, layer=layer, x=x, y=y, z=z)]) + else: + children_coords = get_children_chunk_coords(meta, layer, (x, y, z)) + children_chunk_ids = [] + for (x, y, z) in children_coords: + children_chunk_ids.append( + utils.get_chunk_id(meta, layer=layer - 1, x=x, y=y, z=z) + ) + return np.array(children_chunk_ids) + + +def get_parent_chunk_id( + meta: ChunkedGraphMeta, node_or_chunk_id: np.uint64, parent_layer: int +) -> np.ndarray: + """Parent chunk ID at given layer.""" + node_layer = utils.get_chunk_layer(meta, node_or_chunk_id) + coord = utils.get_chunk_coordinates(meta, node_or_chunk_id) + for _ in range(node_layer, parent_layer): + coord = coord // meta.graph_config.FANOUT + x, y, z = coord + return utils.get_chunk_id(meta, layer=parent_layer, x=x, y=y, z=z) + + +def get_parent_chunk_ids( + meta: ChunkedGraphMeta, node_or_chunk_id: np.uint64 +) -> np.ndarray: + """Creates list of chunk parent ids (upto highest layer).""" + parent_chunk_layers = range( + utils.get_chunk_layer(meta, node_or_chunk_id) + 1, meta.layer_count + 1 + ) + chunk_coord = utils.get_chunk_coordinates(meta, node_or_chunk_id) + parent_chunk_ids = [utils.get_chunk_id(meta, node_or_chunk_id)] + for layer in parent_chunk_layers: + chunk_coord = chunk_coord // meta.graph_config.FANOUT + x, y, z = chunk_coord + parent_chunk_ids.append(utils.get_chunk_id(meta, layer=layer, x=x, y=y, z=z)) + return np.array(parent_chunk_ids, dtype=np.uint64) + + +def get_parent_chunk_id_dict(meta: ChunkedGraphMeta, node_or_chunk_id: np.uint64): + """ + Returns dict of {layer: parent_chunk_id} + (Convenience function) + """ + layer = utils.get_chunk_layer(meta, node_or_chunk_id) + return dict( + zip( + range(layer, meta.layer_count + 1), + get_parent_chunk_ids(meta, node_or_chunk_id), + ) + ) diff --git a/pychunkedgraph/graph/chunks/utils.py b/pychunkedgraph/graph/chunks/utils.py new file mode 100644 index 000000000..dc895bde4 --- /dev/null +++ b/pychunkedgraph/graph/chunks/utils.py @@ -0,0 +1,235 @@ +# pylint: disable=invalid-name, missing-docstring + +from typing import List +from typing import Union +from typing import Optional +from typing import Sequence +from typing import Iterable + +import numpy as np + +def get_chunks_boundary(voxel_boundary, chunk_size) -> np.ndarray: + """returns number of chunks in each dimension""" + return np.ceil((voxel_boundary / chunk_size)).astype(int) + + +def normalize_bounding_box( + meta, + bounding_box: Optional[Sequence[Sequence[int]]], + bbox_is_coordinate: bool, +) -> Union[Sequence[Sequence[int]], None]: + if bounding_box is None: + return None + + bbox = bounding_box.copy() + if bbox_is_coordinate: + bbox[0] = 
_get_chunk_coordinates_from_vol_coordinates( + meta, + bbox[0][0], + bbox[0][1], + bbox[0][2], + resolution=meta.resolution, + ceil=False, + ) + bbox[1] = _get_chunk_coordinates_from_vol_coordinates( + meta, + bbox[1][0], + bbox[1][1], + bbox[1][2], + resolution=meta.resolution, + ceil=True, + ) + return np.array(bbox, dtype=int) + + +def get_chunk_layer(meta, node_or_chunk_id: np.uint64) -> int: + """ Extract Layer from Node ID or Chunk ID """ + return int(int(node_or_chunk_id) >> 64 - meta.graph_config.LAYER_ID_BITS) + + +def get_chunk_layers(meta, node_or_chunk_ids: Sequence[np.uint64]) -> np.ndarray: + """Extract Layers from Node IDs or Chunk IDs + :param node_or_chunk_ids: np.ndarray + :return: np.ndarray + """ + if len(node_or_chunk_ids) == 0: + return np.array([], dtype=int) + + layers = np.array(node_or_chunk_ids, dtype=int) + + layers1 = layers >> (64 - meta.graph_config.LAYER_ID_BITS) + # layers2 = np.vectorize(get_chunk_layer)(meta, node_or_chunk_ids) + # assert np.all(layers1 == layers2) + return layers1 + + +def get_chunk_coordinates(meta, node_or_chunk_id: np.uint64) -> np.ndarray: + """Extract X, Y and Z coordinate from Node ID or Chunk ID + :param node_or_chunk_id: np.uint64 + :return: Tuple(int, int, int) + """ + layer = get_chunk_layer(meta, node_or_chunk_id) + bits_per_dim = meta.bitmasks[layer] + + x_offset = 64 - meta.graph_config.LAYER_ID_BITS - bits_per_dim + y_offset = x_offset - bits_per_dim + z_offset = y_offset - bits_per_dim + + x = int(node_or_chunk_id) >> x_offset & 2 ** bits_per_dim - 1 + y = int(node_or_chunk_id) >> y_offset & 2 ** bits_per_dim - 1 + z = int(node_or_chunk_id) >> z_offset & 2 ** bits_per_dim - 1 + return np.array([x, y, z]) + + +def get_chunk_coordinates_multiple(meta, ids: np.ndarray) -> np.ndarray: + """ + Array version of get_chunk_coordinates. + Assumes all given IDs are in same layer. 
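Since these helpers are pure bit shifting and masking, a self-contained round-trip sketch may help; the bit widths below are placeholders for meta.graph_config.LAYER_ID_BITS and meta.bitmasks[layer].

import numpy as np

LAYER_ID_BITS = 8   # placeholder for meta.graph_config.LAYER_ID_BITS
BITS_PER_DIM = 10   # placeholder for meta.bitmasks[layer]

def encode_chunk_id(layer: int, x: int, y: int, z: int) -> np.uint64:
    """Pack layer and chunk coordinates into a 64-bit ID, highest bits first."""
    layer_offset = 64 - LAYER_ID_BITS
    x_offset = layer_offset - BITS_PER_DIM
    y_offset = x_offset - BITS_PER_DIM
    z_offset = y_offset - BITS_PER_DIM
    return np.uint64(layer << layer_offset | x << x_offset | y << y_offset | z << z_offset)

def decode_chunk_id(chunk_id) -> tuple:
    """Recover (layer, x, y, z) from a packed chunk ID."""
    cid = int(chunk_id)
    mask = 2 ** BITS_PER_DIM - 1
    layer_offset = 64 - LAYER_ID_BITS
    x_offset = layer_offset - BITS_PER_DIM
    y_offset = x_offset - BITS_PER_DIM
    z_offset = y_offset - BITS_PER_DIM
    return cid >> layer_offset, cid >> x_offset & mask, cid >> y_offset & mask, cid >> z_offset & mask

assert decode_chunk_id(encode_chunk_id(2, 3, 5, 7)) == (2, 3, 5, 7)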
+ """ + if not len(ids): + return np.array([]) + layer = get_chunk_layer(meta, ids[0]) + bits_per_dim = meta.bitmasks[layer] + + x_offset = 64 - meta.graph_config.LAYER_ID_BITS - bits_per_dim + y_offset = x_offset - bits_per_dim + z_offset = y_offset - bits_per_dim + + ids = np.array(ids, dtype=int) + X = ids >> x_offset & 2 ** bits_per_dim - 1 + Y = ids >> y_offset & 2 ** bits_per_dim - 1 + Z = ids >> z_offset & 2 ** bits_per_dim - 1 + return np.column_stack((X, Y, Z)) + + +def get_chunk_id( + meta, + node_id: Optional[np.uint64] = None, + layer: Optional[int] = None, + x: Optional[int] = None, + y: Optional[int] = None, + z: Optional[int] = None, +) -> np.uint64: + """(1) Extract Chunk ID from Node ID + (2) Build Chunk ID from Layer, X, Y and Z components + """ + assert node_id is not None or all(v is not None for v in [layer, x, y, z]) + if node_id is not None: + layer = get_chunk_layer(meta, node_id) + bits_per_dim = meta.bitmasks[layer] + + if node_id is not None: + chunk_offset = 64 - meta.graph_config.LAYER_ID_BITS - 3 * bits_per_dim + return np.uint64((int(node_id) >> chunk_offset) << chunk_offset) + return _compute_chunk_id(meta, layer, x, y, z) + + +def get_chunk_ids_from_coords(meta, layer: int, coords: np.ndarray): + result = np.zeros(len(coords), dtype=np.uint64) + s_bits_per_dim = meta.bitmasks[layer] + + layer_offset = 64 - meta.graph_config.LAYER_ID_BITS + x_offset = layer_offset - s_bits_per_dim + y_offset = x_offset - s_bits_per_dim + z_offset = y_offset - s_bits_per_dim + coords = np.array(coords, dtype=np.uint64) + + result |= layer << layer_offset + result |= coords[:, 0] << x_offset + result |= coords[:, 1] << y_offset + result |= coords[:, 2] << z_offset + return result + + +def get_chunk_ids_from_node_ids(meta, ids: Iterable[np.uint64]) -> np.ndarray: + """ Extract Chunk IDs from Node IDs""" + if len(ids) == 0: + return np.array([], dtype=np.uint64) + + bits_per_dims = np.array([meta.bitmasks[l] for l in get_chunk_layers(meta, ids)]) + offsets = 64 - meta.graph_config.LAYER_ID_BITS - 3 * bits_per_dims + + cids1 = np.array((np.array(ids, dtype=int) >> offsets) << offsets, dtype=np.uint64) + # cids2 = np.vectorize(get_chunk_id)(meta, ids) + # assert np.all(cids1 == cids2) + return cids1 + + +def _compute_chunk_id( + meta, + layer: int, + x: int, + y: int, + z: int, +) -> np.uint64: + s_bits_per_dim = meta.bitmasks[layer] + if not ( + x < 2 ** s_bits_per_dim and y < 2 ** s_bits_per_dim and z < 2 ** s_bits_per_dim + ): + raise ValueError( + f"Coordinate is out of range \ + layer: {layer} bits/dim {s_bits_per_dim}. \ + [{x}, {y}, {z}]; max = {2 ** s_bits_per_dim}." 
+ ) + layer_offset = 64 - meta.graph_config.LAYER_ID_BITS + x_offset = layer_offset - s_bits_per_dim + y_offset = x_offset - s_bits_per_dim + z_offset = y_offset - s_bits_per_dim + return np.uint64( + layer << layer_offset | x << x_offset | y << y_offset | z << z_offset + ) + + +def _get_chunk_coordinates_from_vol_coordinates( + meta, + x: int, + y: int, + z: int, + resolution: Sequence[int], + ceil: bool = False, + layer: int = 1, +) -> np.ndarray: + """Translates volume coordinates to chunk_coordinates.""" + resolution = np.array(resolution) + scaling = np.array(meta.resolution / resolution, dtype=int) + + chunk_size = meta.graph_config.CHUNK_SIZE + x = (x / scaling[0] - meta.voxel_bounds[0, 0]) / chunk_size[0] + y = (y / scaling[1] - meta.voxel_bounds[1, 0]) / chunk_size[1] + z = (z / scaling[2] - meta.voxel_bounds[2, 0]) / chunk_size[2] + + x /= meta.graph_config.FANOUT ** (max(layer - 2, 0)) + y /= meta.graph_config.FANOUT ** (max(layer - 2, 0)) + z /= meta.graph_config.FANOUT ** (max(layer - 2, 0)) + + coords = np.array([x, y, z]) + if ceil: + coords = np.ceil(coords) + return coords.astype(int) + + +def get_bounding_children_chunks( + cg_meta, layer: int, chunk_coords: Sequence[int], children_layer, return_unique=True +) -> np.ndarray: + """Children chunk coordinates at given layer, along the boundary of a chunk""" + chunk_coords = np.array(chunk_coords, dtype=int) + chunks = [] + + # children chunk count along one dimension + chunks_count = cg_meta.graph_config.FANOUT ** (layer - children_layer) + chunk_offset = chunk_coords * chunks_count + x1, y1, z1 = chunk_offset + x2, y2, z2 = chunk_offset + chunks_count + + # https://stackoverflow.com/a/35608701/2683367 + f = lambda r1, r2, r3: np.array(np.meshgrid(r1, r2, r3), dtype=int).T.reshape(-1, 3) + chunks.append(f((x1, x2 - 1), range(y1, y2), range(z1, z2))) + chunks.append(f(range(x1, x2), (y1, y2 - 1), range(z1, z2))) + chunks.append(f(range(x1, x2), range(y1, y2), (z1, z2 - 1))) + + chunks = np.concatenate(chunks) + mask = np.all(chunks < cg_meta.layer_chunk_bounds[children_layer], axis=1) + result = chunks[mask] + if return_unique: + return np.unique(result, axis=0) if result.size else result + return result diff --git a/pychunkedgraph/graph/client/__init__.py b/pychunkedgraph/graph/client/__init__.py new file mode 100644 index 000000000..6e025bd35 --- /dev/null +++ b/pychunkedgraph/graph/client/__init__.py @@ -0,0 +1,44 @@ +""" +Sub packages/modules for backend storage clients +Currently supports Google Big Table + +A simple client needs to be able to create the graph, +store graph meta and to write and read node information. +Also needs locking support to prevent race conditions +when modifying root/parent nodes. + +In addition, clients with more features like generating unique IDs +and logging facilities can be implemented by inherting respective base classes. + +These methods are in separate classes because they are logically related. +This also makes it possible to have different backend storage solutions, +making it possible to use any unique features these solutions may provide. + +Please see `base.py` for more details. +""" + +from collections import namedtuple + +from .bigtable.client import Client as BigTableClient + + +_backend_clientinfo_fields = ("TYPE", "CONFIG") +_backend_clientinfo_defaults = (None, None) +BackendClientInfo = namedtuple( + "BackendClientInfo", + _backend_clientinfo_fields, + defaults=_backend_clientinfo_defaults, +) + + +def get_default_client_info(): + """ + Load client from env variables. 
+ """ + + # TODO make dynamic after multiple platform support is added + from .bigtable import get_client_info as get_bigtable_client_info + + return BackendClientInfo( + CONFIG=get_bigtable_client_info(admin=True, read_only=False) + ) diff --git a/pychunkedgraph/graph/client/base.py b/pychunkedgraph/graph/client/base.py new file mode 100644 index 000000000..a66602a6a --- /dev/null +++ b/pychunkedgraph/graph/client/base.py @@ -0,0 +1,152 @@ +from abc import ABC +from abc import abstractmethod + + +class SimpleClient(ABC): + """ + Abstract class for interacting with backend data store where the chunkedgraph is stored. + Eg., BigTableClient for using big table as storage. + """ + + @abstractmethod + def create_graph(self) -> None: + """Initialize the graph and store associated meta.""" + + @abstractmethod + def add_graph_version(self, version): + """Add a version to the graph.""" + + @abstractmethod + def read_graph_version(self): + """Read stored graph version.""" + + @abstractmethod + def update_graph_meta(self, meta): + """Update stored graph meta.""" + + @abstractmethod + def read_graph_meta(self): + """Read stored graph meta.""" + + @abstractmethod + def read_nodes( + self, + start_id=None, + end_id=None, + node_ids=None, + properties=None, + start_time=None, + end_time=None, + end_time_inclusive=False, + ): + """ + Read nodes and their properties. + Accepts a range of node IDs or specific node IDs. + """ + + @abstractmethod + def read_node( + self, + node_id, + properties=None, + start_time=None, + end_time=None, + end_time_inclusive=False, + ): + """Read a single node and it's properties.""" + + @abstractmethod + def write_nodes(self, nodes): + """Writes/updates nodes (IDs along with properties).""" + + @abstractmethod + def lock_root(self, node_id, operation_id): + """Locks root node with operation_id to prevent race conditions.""" + + @abstractmethod + def lock_roots(self, node_ids, operation_id): + """Locks root nodes to prevent race conditions.""" + + @abstractmethod + def lock_root_indefinitely(self, node_id, operation_id): + """Locks root node with operation_id to prevent race conditions.""" + + @abstractmethod + def lock_roots_indefinitely(self, node_ids, operation_id): + """ + Locks root nodes indefinitely to prevent structural damage to graph. + This scenario is rare and needs asynchronous fix or inspection to unlock. + """ + + @abstractmethod + def unlock_root(self, node_id, operation_id): + """Unlocks root node that is locked with operation_id.""" + + @abstractmethod + def unlock_indefinitely_locked_root(self, node_id, operation_id): + """Unlocks root node that is indefinitely locked with operation_id.""" + + @abstractmethod + def renew_lock(self, node_id, operation_id): + """Renews existing node lock with operation_id for extended time.""" + + @abstractmethod + def renew_locks(self, node_ids, operation_id): + """Renews existing node locks with operation_id for extended time.""" + + @abstractmethod + def get_lock_timestamp(self, node_ids, operation_id): + """Reads timestamp from lock row to get a consistent timestamp.""" + + @abstractmethod + def get_consolidated_lock_timestamp(self, root_ids, operation_ids): + """Minimum of multiple lock timestamps.""" + + @abstractmethod + def get_compatible_timestamp(self, time_stamp): + """Datetime time stamp compatible with client's services.""" + + +class ClientWithIDGen(SimpleClient): + """ + Abstract class for client to backend data store that has support for generating IDs. 
+ If not, something else can be used but these methods need to be implemented. + Eg., Big Table row cells can be used to generate unique IDs. + """ + + @abstractmethod + def create_node_ids(self, chunk_id): + """Generate a range of unique IDs in the chunk.""" + + @abstractmethod + def create_node_id(self, chunk_id): + """Generate a unique ID in the chunk.""" + + @abstractmethod + def get_max_node_id(self, chunk_id): + """Gets the current maximum node ID in the chunk.""" + + @abstractmethod + def create_operation_id(self): + """Generate a unique operation ID.""" + + @abstractmethod + def get_max_operation_id(self): + """Gets the current maximum operation ID.""" + + +class OperationLogger(ABC): + """ + Abstract class for interacting with backend data store where the operation logs are stored. + Eg., BigTableClient can be used to store logs in Google BigTable. + """ + + # TODO add functions for writing + + @abstractmethod + def read_log_entry(self, operation_id: int) -> None: + """Read log entry for a given operation ID.""" + + @abstractmethod + def read_log_entries(self, operation_ids) -> None: + """Read log entries for given operation IDs.""" diff --git a/pychunkedgraph/graph/client/bigtable/__init__.py b/pychunkedgraph/graph/client/bigtable/__init__.py new file mode 100644 index 000000000..b3dbd777b --- /dev/null +++ b/pychunkedgraph/graph/client/bigtable/__init__.py @@ -0,0 +1,49 @@ +from collections import namedtuple +from os import environ + +DEFAULT_PROJECT = "neuromancer-seung-import" +DEFAULT_INSTANCE = "pychunkedgraph" + +_bigtableconfig_fields = ( + "PROJECT", + "INSTANCE", + "ADMIN", + "READ_ONLY", + "CREDENTIALS", + "MAX_ROW_KEY_COUNT" +) +_bigtableconfig_defaults = ( + environ.get("BIGTABLE_PROJECT", DEFAULT_PROJECT), + environ.get("BIGTABLE_INSTANCE", DEFAULT_INSTANCE), + False, + True, + None, + 1000 +) +BigTableConfig = namedtuple( + "BigTableConfig", _bigtableconfig_fields, defaults=_bigtableconfig_defaults +) + + +def get_client_info( + project: str = None, + instance: str = None, + admin: bool = False, + read_only: bool = True, +): + """Helper function to load config from env.""" + _project = environ.get("BIGTABLE_PROJECT", DEFAULT_PROJECT) + if project: + _project = project + + _instance = environ.get("BIGTABLE_INSTANCE", DEFAULT_INSTANCE) + if instance: + _instance = instance + + kwargs = { + "PROJECT": _project, + "INSTANCE": _instance, + "ADMIN": admin, + "READ_ONLY": read_only, + } + return BigTableConfig(**kwargs) diff --git a/pychunkedgraph/graph/client/bigtable/client.py b/pychunkedgraph/graph/client/bigtable/client.py new file mode 100644 index 000000000..5b86826bd --- /dev/null +++ b/pychunkedgraph/graph/client/bigtable/client.py @@ -0,0 +1,860 @@ +# pylint: disable=invalid-name, missing-docstring, import-outside-toplevel, line-too-long, protected-access, arguments-differ, arguments-renamed, logging-fstring-interpolation + +import sys +import time +import typing +import logging +import datetime +from datetime import datetime + +import numpy as np +from multiwrapper import multiprocessing_utils as mu +from google.cloud import bigtable +from google.api_core.retry import Retry +from google.api_core.retry import if_exception_type +from google.api_core.exceptions import Aborted +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import ServiceUnavailable +from google.cloud.bigtable.table import Table +from google.cloud.bigtable.row_set import RowSet +from google.cloud.bigtable.row_data import PartialRowData +from 
google.cloud.bigtable.row_filters import RowFilter +from google.cloud.bigtable.column_family import MaxVersionsGCRule + +from . import utils +from . import BigTableConfig +from ..base import ClientWithIDGen +from ..base import OperationLogger +from ... import attributes +from ... import exceptions +from ...utils import basetypes +from ...utils.serializers import pad_node_id +from ...utils.serializers import serialize_key +from ...utils.serializers import serialize_uint64 +from ...utils.serializers import deserialize_uint64 +from ...meta import ChunkedGraphMeta +from ...utils.generic import get_valid_timestamp + + +class Client(bigtable.Client, ClientWithIDGen, OperationLogger): + def __init__( + self, + table_id: str, + config: BigTableConfig = BigTableConfig(), + graph_meta: ChunkedGraphMeta = None, + ): + if config.CREDENTIALS: + super(Client, self).__init__( + project=config.PROJECT, + read_only=config.READ_ONLY, + admin=config.ADMIN, + credentials=config.CREDENTIALS, + ) + else: + super(Client, self).__init__( + project=config.PROJECT, + read_only=config.READ_ONLY, + admin=config.ADMIN, + ) + self._instance = self.instance(config.INSTANCE) + self._table = self._instance.table(table_id) + + self.logger = logging.getLogger( + f"{config.PROJECT}/{config.INSTANCE}/{table_id}" + ) + self.logger.setLevel(logging.WARNING) + if not self.logger.handlers: + sh = logging.StreamHandler(sys.stdout) + sh.setLevel(logging.WARNING) + self.logger.addHandler(sh) + self._graph_meta = graph_meta + self._version = None + self._max_row_key_count = config.MAX_ROW_KEY_COUNT + + @property + def graph_meta(self): + return self._graph_meta + + def create_graph(self, meta: ChunkedGraphMeta, version: str) -> None: + """Initialize the graph and store associated meta.""" + if self._table.exists(): + raise ValueError(f"{self._table.table_id} already exists.") + self._table.create() + self._create_column_families() + self.add_graph_version(version) + self.update_graph_meta(meta) + + def add_graph_version(self, version: str): + assert self.read_graph_version() is None, "Graph has already been versioned." + self._version = version + row = self.mutate_row( + attributes.GraphVersion.key, + {attributes.GraphVersion.Version: version}, + ) + self.write([row]) + + def read_graph_version(self) -> str: + try: + row = self._read_byte_row(attributes.GraphVersion.key) + self._version = row[attributes.GraphVersion.Version][0].value + return self._version + except KeyError: + return None + + def _delete_meta(self): + # temprorary fix, use new column with GCRule for permanent fix + # delete existing meta before update, but compatibilty issues + meta_row = self._table.direct_row(attributes.GraphMeta.key) + meta_row.delete() + meta_row.commit() + + def update_graph_meta( + self, meta: ChunkedGraphMeta, overwrite: typing.Optional[bool] = False + ): + if overwrite: + self._delete_meta() + self._graph_meta = meta + row = self.mutate_row( + attributes.GraphMeta.key, + {attributes.GraphMeta.Meta: meta}, + ) + self.write([row]) + + def read_graph_meta(self) -> ChunkedGraphMeta: + row = self._read_byte_row(attributes.GraphMeta.key) + self._graph_meta = row[attributes.GraphMeta.Meta][0].value + return self._graph_meta + + def read_nodes( + self, + start_id=None, + end_id=None, + end_id_inclusive=False, + user_id=None, + node_ids=None, + properties=None, + start_time=None, + end_time=None, + end_time_inclusive: bool = False, + fake_edges: bool = False, + ): + """ + Read nodes and their properties. 
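A construction sketch for the client defined above, assuming BigTable credentials are available in the environment; project, instance, and table names are placeholders.

from pychunkedgraph.graph.client import BigTableClient
from pychunkedgraph.graph.client.bigtable import BigTableConfig

config = BigTableConfig(
    PROJECT="my-gcp-project",           # placeholder
    INSTANCE="my-bigtable-instance",    # placeholder
    ADMIN=False,
    READ_ONLY=True,
)
client = BigTableClient("my_graph_table", config=config)  # placeholder table id
print(client.read_graph_version())
print(client.read_graph_meta().layer_count)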
+ Accepts a range of node IDs or specific node IDs. + """ + if node_ids is not None and len(node_ids) > self._max_row_key_count: + # bigtable reading is faster + # when all IDs in a block are within a range + node_ids = np.sort(node_ids) + rows = self._read_byte_rows( + start_key=serialize_uint64(start_id, fake_edges=fake_edges) + if start_id is not None + else None, + end_key=serialize_uint64(end_id, fake_edges=fake_edges) + if end_id is not None + else None, + end_key_inclusive=end_id_inclusive, + row_keys=( + serialize_uint64(node_id, fake_edges=fake_edges) for node_id in node_ids + ) + if node_ids is not None + else None, + columns=properties, + start_time=start_time, + end_time=end_time, + end_time_inclusive=end_time_inclusive, + user_id=user_id, + ) + return { + deserialize_uint64(row_key, fake_edges=fake_edges): data + for (row_key, data) in rows.items() + } + + def read_node( + self, + node_id: np.uint64, + properties: typing.Optional[ + typing.Union[typing.Iterable[attributes._Attribute], attributes._Attribute] + ] = None, + start_time: typing.Optional[datetime] = None, + end_time: typing.Optional[datetime] = None, + end_time_inclusive: bool = False, + fake_edges: bool = False, + ) -> typing.Union[ + typing.Dict[attributes._Attribute, typing.List[bigtable.row_data.Cell]], + typing.List[bigtable.row_data.Cell], + ]: + """Convenience function for reading a single node from Bigtable. + Arguments: + node_id {np.uint64} -- the NodeID of the row to be read. + Keyword Arguments: + columns {typing.Optional[typing.Union[typing.Iterable[attributes._Attribute], attributes._Attribute]]} -- + typing.Optional filtering by columns to speed up the query. If `columns` is a single + column (not iterable), the column key will be omitted from the result. + (default: {None}) + start_time {typing.Optional[datetime]} -- Ignore cells with timestamp before + `start_time`. If None, no lower bound. (default: {None}) + end_time {typing.Optional[datetime]} -- Ignore cells with timestamp after `end_time`. + If None, no upper bound. (default: {None}) + end_time_inclusive {bool} -- Whether or not `end_time` itself should be included in the + request, ignored if `end_time` is None. (default: {False}) + Returns: + typing.Union[typing.Dict[attributes._Attribute, typing.List[bigtable.row_data.Cell]], + typing.List[bigtable.row_data.Cell]] -- + Returns a mapping of columns to a typing.List of cells (one cell per timestamp). Each cell + has a `value` property, which returns the deserialized field, and a `timestamp` + property, which returns the timestamp as `datetime` object. + If only a single `attributes._Attribute` was requested, the typing.List of cells is returned + directly. + """ + return self._read_byte_row( + row_key=serialize_uint64(node_id, fake_edges=fake_edges), + columns=properties, + start_time=start_time, + end_time=end_time, + end_time_inclusive=end_time_inclusive, + ) + + def write_nodes(self, nodes, root_ids=None, operation_id=None): + """ + Writes/updates nodes (IDs along with properties) + by locking root nodes until changes are written. 
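The bulk-write helper further below retries transient BigTable errors with exponential backoff; here is that retry policy reproduced in isolation (the deadline is a placeholder for graph_config.ROOT_LOCK_EXPIRY.seconds).

from google.api_core.retry import Retry
from google.api_core.retry import if_exception_type
from google.api_core.exceptions import Aborted
from google.api_core.exceptions import DeadlineExceeded
from google.api_core.exceptions import ServiceUnavailable

retry = Retry(
    predicate=if_exception_type((Aborted, DeadlineExceeded, ServiceUnavailable)),
    initial=5.0,     # the slow_retry=True path
    maximum=15.0,
    multiplier=2.0,
    deadline=180.0,  # placeholder for graph_config.ROOT_LOCK_EXPIRY.seconds
)

def mutate_block(table, block):
    """Send one block of mutated rows, retrying transient BigTable errors."""
    return table.mutate_rows(block, retry=retry)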
+ """ + + def read_log_entry( + self, operation_id: np.uint64 + ) -> typing.Tuple[typing.Dict, datetime]: + log_record = self.read_node( + operation_id, properties=attributes.OperationLogs.all() + ) + if len(log_record) == 0: + return {}, None + try: + timestamp = log_record[attributes.OperationLogs.OperationTimeStamp][0].value + except KeyError: + timestamp = log_record[attributes.OperationLogs.RootID][0].timestamp + log_record.update((column, v[0].value) for column, v in log_record.items()) + return log_record, timestamp + + def read_log_entries( + self, + operation_ids: typing.Optional[typing.Iterable] = None, + user_id: typing.Optional[str] = None, + properties: typing.Optional[typing.Iterable[attributes._Attribute]] = None, + start_time: typing.Optional[datetime] = None, + end_time: typing.Optional[datetime] = None, + end_time_inclusive: bool = False, + ): + if properties is None: + properties = attributes.OperationLogs.all() + + if operation_ids is None: + logs_d = self.read_nodes( + start_id=np.uint64(0), + end_id=self.get_max_operation_id(), + end_id_inclusive=True, + user_id=user_id, + properties=properties, + start_time=start_time, + end_time=end_time, + end_time_inclusive=end_time_inclusive, + ) + else: + logs_d = self.read_nodes( + node_ids=operation_ids, + properties=properties, + start_time=start_time, + end_time=end_time, + end_time_inclusive=end_time_inclusive, + user_id=user_id, + ) + if not logs_d: + return {} + for operation_id in logs_d: + log_record = logs_d[operation_id] + try: + timestamp = log_record[attributes.OperationLogs.OperationTimeStamp][ + 0 + ].value + except KeyError: + timestamp = log_record[attributes.OperationLogs.RootID][0].timestamp + log_record.update((column, v[0].value) for column, v in log_record.items()) + log_record["timestamp"] = timestamp + return logs_d + + # Helpers + def write( + self, + rows: typing.Iterable[bigtable.row.DirectRow], + root_ids: typing.Optional[ + typing.Union[np.uint64, typing.Iterable[np.uint64]] + ] = None, + operation_id: typing.Optional[np.uint64] = None, + slow_retry: bool = True, + block_size: int = 2000, + ): + """Writes a list of mutated rows in bulk + WARNING: If contains the same row (same row_key) and column + key two times only the last one is effectively written to the BigTable + (even when the mutations were applied to different columns) + --> no versioning! + :param rows: list + list of mutated rows + :param root_ids: list if uint64 + :param operation_id: uint64 or None + operation_id (or other unique id) that *was* used to lock the root + the bulk write is only executed if the root is still locked with + the same id. 
+ :param slow_retry: bool + :param block_size: int + """ + if slow_retry: + initial = 5 + else: + initial = 1 + + exception_types = (Aborted, DeadlineExceeded, ServiceUnavailable) + retry = Retry( + predicate=if_exception_type(exception_types), + initial=initial, + maximum=15.0, + multiplier=2.0, + deadline=self.graph_meta.graph_config.ROOT_LOCK_EXPIRY.seconds, + ) + + if root_ids is not None and operation_id is not None: + if isinstance(root_ids, int): + root_ids = [root_ids] + if not self.renew_locks(root_ids, operation_id): + raise exceptions.LockingError( + f"Root lock renewal failed: operation {operation_id}" + ) + + for i in range(0, len(rows), block_size): + status = self._table.mutate_rows(rows[i : i + block_size], retry=retry) + if not all(status): + raise exceptions.ChunkedGraphError( + f"Bulk write failed: operation {operation_id}" + ) + + def mutate_row( + self, + row_key: bytes, + val_dict: typing.Dict[attributes._Attribute, typing.Any], + time_stamp: typing.Optional[datetime] = None, + ) -> bigtable.row.Row: + """Mutates a single row (doesn't write to big table).""" + row = self._table.direct_row(row_key) + for column, value in val_dict.items(): + row.set_cell( + column_family_id=column.family_id, + column=column.key, + value=column.serialize(value), + timestamp=time_stamp, + ) + return row + + # Locking + def lock_root( + self, + root_id: np.uint64, + operation_id: np.uint64, + ) -> bool: + """Attempts to lock the latest version of a root node.""" + lock_expiry = self.graph_meta.graph_config.ROOT_LOCK_EXPIRY + lock_column = attributes.Concurrency.Lock + indefinite_lock_column = attributes.Concurrency.IndefiniteLock + filter_ = utils.get_root_lock_filter( + lock_column, lock_expiry, indefinite_lock_column + ) + + root_row = self._table.conditional_row( + serialize_uint64(root_id), filter_=filter_ + ) + # Set row lock if condition returns no results (state == False) + root_row.set_cell( + lock_column.family_id, + lock_column.key, + serialize_uint64(operation_id), + state=False, + timestamp=get_valid_timestamp(None), + ) + + # The lock was acquired when set_cell returns False (state) + lock_acquired = not root_row.commit() + if not lock_acquired: + row = self._read_byte_row(serialize_uint64(root_id), columns=lock_column) + l_operation_ids = [cell.value for cell in row] + self.logger.debug(f"Locked operation ids: {l_operation_ids}") + return lock_acquired + + def lock_root_indefinitely( + self, + root_id: np.uint64, + operation_id: np.uint64, + ) -> bool: + """Attempts to indefinitely lock the latest version of a root node.""" + lock_column = attributes.Concurrency.IndefiniteLock + filter_ = utils.get_indefinite_root_lock_filter(lock_column) + root_row = self._table.conditional_row( + serialize_uint64(root_id), filter_=filter_ + ) + # Set row lock if condition returns no results (state == False) + root_row.set_cell( + lock_column.family_id, + lock_column.key, + serialize_uint64(operation_id), + state=False, + timestamp=get_valid_timestamp(None), + ) + + # The lock was acquired when set_cell returns False (state) + lock_acquired = not root_row.commit() + if not lock_acquired: + row = self._read_byte_row(serialize_uint64(root_id), columns=lock_column) + l_operation_ids = [cell.value for cell in row] + self.logger.debug(f"Indefinitely locked operation ids: {l_operation_ids}") + return lock_acquired + + def lock_roots( + self, + root_ids: typing.Sequence[np.uint64], + operation_id: np.uint64, + future_root_ids_d: typing.Dict, + max_tries: int = 1, + waittime_s: float = 0.5, + ) -> 
typing.Tuple[bool, typing.Iterable]: + """Attempts to lock multiple nodes with same operation id""" + i_try = 0 + while i_try < max_tries: + lock_acquired = False + # Collect latest root ids + new_root_ids: typing.List[np.uint64] = [] + for root_id in root_ids: + future_root_ids = future_root_ids_d[root_id] + if not future_root_ids.size: + new_root_ids.append(root_id) + else: + new_root_ids.extend(future_root_ids) + + # Attempt to lock all latest root ids + root_ids = np.unique(new_root_ids) + for root_id in root_ids: + lock_acquired = self.lock_root(root_id, operation_id) + # Roll back locks if one root cannot be locked + if not lock_acquired: + for id_ in root_ids: + self.unlock_root(id_, operation_id) + break + + if lock_acquired: + return True, root_ids + time.sleep(waittime_s) + i_try += 1 + self.logger.debug(f"Try {i_try}") + return False, root_ids + + def lock_roots_indefinitely( + self, + root_ids: typing.Sequence[np.uint64], + operation_id: np.uint64, + future_root_ids_d: typing.Dict, + ) -> typing.Tuple[bool, typing.Iterable]: + """Attempts to indefinitely lock multiple nodes with same operation id""" + lock_acquired = False + # Collect latest root ids + new_root_ids: typing.List[np.uint64] = [] + for _id in root_ids: + future_root_ids = future_root_ids_d.get(_id) + if not future_root_ids.size: + new_root_ids.append(_id) + else: + new_root_ids.extend(future_root_ids) + + # Attempt to lock all latest root ids + failed_to_lock_id = None + root_ids = np.unique(new_root_ids) + for _id in root_ids: + self.logger.debug(f"operation {operation_id} root_id {_id}") + lock_acquired = self.lock_root_indefinitely(_id, operation_id) + # Roll back locks if one root cannot be locked + if not lock_acquired: + failed_to_lock_id = _id + for id_ in root_ids: + self.unlock_indefinitely_locked_root(id_, operation_id) + break + if lock_acquired: + return True, root_ids, failed_to_lock_id + return False, root_ids, failed_to_lock_id + + def unlock_root(self, root_id: np.uint64, operation_id: np.uint64): + """Unlocks root node that is locked with operation_id.""" + lock_column = attributes.Concurrency.Lock + expiry = self.graph_meta.graph_config.ROOT_LOCK_EXPIRY + root_row = self._table.conditional_row( + serialize_uint64(root_id), + filter_=utils.get_unlock_root_filter(lock_column, expiry, operation_id), + ) + # Delete row if conditions are met (state == True) + root_row.delete_cell(lock_column.family_id, lock_column.key, state=True) + return root_row.commit() + + def unlock_indefinitely_locked_root( + self, root_id: np.uint64, operation_id: np.uint64 + ): + """Unlocks root node that is indefinitely locked with operation_id.""" + lock_column = attributes.Concurrency.IndefiniteLock + # Get conditional row using the chained filter + root_row = self._table.conditional_row( + serialize_uint64(root_id), + filter_=utils.get_indefinite_unlock_root_filter(lock_column, operation_id), + ) + # Delete row if conditions are met (state == True) + root_row.delete_cell(lock_column.family_id, lock_column.key, state=True) + return root_row.commit() + + def renew_lock(self, root_id: np.uint64, operation_id: np.uint64) -> bool: + """Renews existing root node lock with operation_id to extend time.""" + lock_column = attributes.Concurrency.Lock + root_row = self._table.conditional_row( + serialize_uint64(root_id), + filter_=utils.get_renew_lock_filter(lock_column, operation_id), + ) + # Set row lock if condition returns a result (state == True) + root_row.set_cell( + lock_column.family_id, + lock_column.key, + 
lock_column.serialize(operation_id), + state=False, + ) + # The lock was acquired when set_cell returns True (state) + return not root_row.commit() + + def renew_locks(self, root_ids: np.uint64, operation_id: np.uint64) -> bool: + """Renews existing root node locks with operation_id to extend time.""" + for root_id in root_ids: + if not self.renew_lock(root_id, operation_id): + self.logger.warning(f"renew_lock failed - {root_id}") + return False + return True + + def get_lock_timestamp( + self, root_id: np.uint64, operation_id: np.uint64 + ) -> typing.Union[datetime, None]: + """Lock timestamp for a Root ID operation.""" + row = self.read_node(root_id, properties=attributes.Concurrency.Lock) + if len(row) == 0: + self.logger.warning(f"No lock found for {root_id}") + return None + if row[0].value != operation_id: + self.logger.warning(f"{root_id} not locked with {operation_id}") + return None + return row[0].timestamp + + def get_consolidated_lock_timestamp( + self, + root_ids: typing.Sequence[np.uint64], + operation_ids: typing.Sequence[np.uint64], + ) -> typing.Union[datetime, None]: + """Minimum of multiple lock timestamps.""" + time_stamps = [] + for root_id, operation_id in zip(root_ids, operation_ids): + time_stamp = self.get_lock_timestamp(root_id, operation_id) + if time_stamp is None: + return None + time_stamps.append(time_stamp) + if len(time_stamps) == 0: + return None + return np.min(time_stamps) + + # IDs + def create_node_ids( + self, chunk_id: np.uint64, size: int, root_chunk=False + ) -> np.ndarray: + """Generates a list of unique node IDs for the given chunk.""" + if root_chunk: + new_ids = self._get_root_segment_ids_range(chunk_id, size) + else: + low, high = self._get_ids_range( + serialize_uint64(chunk_id, counter=True), size + ) + low, high = basetypes.SEGMENT_ID.type(low), basetypes.SEGMENT_ID.type(high) + new_ids = np.arange(low, high + np.uint64(1), dtype=basetypes.SEGMENT_ID) + return new_ids | chunk_id + + def create_node_id( + self, chunk_id: np.uint64, root_chunk=False + ) -> basetypes.NODE_ID: + """Generate a unique node ID in the chunk.""" + return self.create_node_ids(chunk_id, 1, root_chunk=root_chunk)[0] + + def get_max_node_id( + self, chunk_id: basetypes.CHUNK_ID, root_chunk=False + ) -> basetypes.NODE_ID: + """Gets the current maximum segment ID in the chunk.""" + if root_chunk: + n_counters = np.uint64(2**8) + max_value = 0 + for counter in range(n_counters): + row = self._read_byte_row( + serialize_key(f"i{pad_node_id(chunk_id)}_{counter}"), + columns=attributes.Concurrency.Counter, + ) + val = ( + basetypes.SEGMENT_ID.type(row[0].value if row else 0) * n_counters + + counter + ) + max_value = val if val > max_value else max_value + return chunk_id | basetypes.SEGMENT_ID.type(max_value) + column = attributes.Concurrency.Counter + row = self._read_byte_row( + serialize_uint64(chunk_id, counter=True), columns=column + ) + return chunk_id | basetypes.SEGMENT_ID.type(row[0].value if row else 0) + + def create_operation_id(self): + """Generate a unique operation ID.""" + return self._get_ids_range(attributes.OperationLogs.key, 1)[1] + + def get_max_operation_id(self): + """Gets the current maximum operation ID.""" + column = attributes.Concurrency.Counter + row = self._read_byte_row(attributes.OperationLogs.key, columns=column) + return row[0].value if row else column.basetype(0) + + def get_compatible_timestamp( + self, time_stamp: datetime, round_up: bool = False + ) -> datetime: + return utils.get_google_compatible_time_stamp(time_stamp, 
round_up=round_up) + + # PRIVATE METHODS + def _create_column_families(self): + f = self._table.column_family("0") + f.create() + f = self._table.column_family("1", gc_rule=MaxVersionsGCRule(1)) + f.create() + f = self._table.column_family("2") + f.create() + f = self._table.column_family("3") + f.create() + + def _get_ids_range(self, key: bytes, size: int) -> typing.Tuple: + """Returns a range (min, max) of IDs for a given `key`.""" + column = attributes.Concurrency.Counter + row = self._table.append_row(key) + row.increment_cell_value(column.family_id, column.key, size) + row = row.commit() + high = column.deserialize(row[column.family_id][column.key][0][0]) + return high + np.uint64(1) - size, high + + def _get_root_segment_ids_range( + self, chunk_id: basetypes.CHUNK_ID, size: int = 1, counter: int = None + ) -> np.ndarray: + """Return unique segment ID for the root chunk.""" + n_counters = np.uint64(2**8) + counter = ( + np.uint64(counter % n_counters) + if counter + else np.uint64(np.random.randint(0, n_counters)) + ) + key = serialize_key(f"i{pad_node_id(chunk_id)}_{counter}") + min_, max_ = self._get_ids_range(key=key, size=size) + return np.arange( + min_ * n_counters + counter, + max_ * n_counters + np.uint64(1) + counter, + n_counters, + dtype=basetypes.SEGMENT_ID, + ) + + def _read_byte_rows( + self, + start_key: typing.Optional[bytes] = None, + end_key: typing.Optional[bytes] = None, + end_key_inclusive: bool = False, + row_keys: typing.Optional[typing.Iterable[bytes]] = None, + columns: typing.Optional[ + typing.Union[typing.Iterable[attributes._Attribute], attributes._Attribute] + ] = None, + start_time: typing.Optional[datetime] = None, + end_time: typing.Optional[datetime] = None, + end_time_inclusive: bool = False, + user_id: typing.Optional[str] = None, + ) -> typing.Dict[ + bytes, + typing.Union[ + typing.Dict[attributes._Attribute, typing.List[bigtable.row_data.Cell]], + typing.List[bigtable.row_data.Cell], + ], + ]: + """Main function for reading a row range or non-contiguous row sets from Bigtable using + `bytes` keys. + + Keyword Arguments: + start_key {typing.Optional[bytes]} -- The first row to be read, ignored if `row_keys` is set. + If None, no lower boundary is used. (default: {None}) + end_key {typing.Optional[bytes]} -- The end of the row range, ignored if `row_keys` is set. + If None, no upper boundary is used. (default: {None}) + end_key_inclusive {bool} -- Whether or not `end_key` itself should be included in the + request, ignored if `row_keys` is set or `end_key` is None. (default: {False}) + row_keys {typing.Optional[typing.Iterable[bytes]]} -- An `typing.Iterable` containing possibly + non-contiguous row keys. Takes precedence over `start_key` and `end_key`. + (default: {None}) + columns {typing.Optional[typing.Union[typing.Iterable[attributes._Attribute], attributes._Attribute]]} -- + typing.Optional filtering by columns to speed up the query. If `columns` is a single + column (not iterable), the column key will be omitted from the result. + (default: {None}) + start_time {typing.Optional[datetime]} -- Ignore cells with timestamp before + `start_time`. If None, no lower bound. (default: {None}) + end_time {typing.Optional[datetime]} -- Ignore cells with timestamp after `end_time`. + If None, no upper bound. (default: {None}) + end_time_inclusive {bool} -- Whether or not `end_time` itself should be included in the + request, ignored if `end_time` is None. 
(default: {False}) + user_id {typing.Optional[str]} -- Only return cells with userID equal to this + + Returns: + typing.Dict[bytes, typing.Union[typing.Dict[attributes._Attribute, typing.List[bigtable.row_data.Cell]], + typing.List[bigtable.row_data.Cell]]] -- + Returns a dictionary of `byte` rows as keys. Their value will be a mapping of + columns to a typing.List of cells (one cell per timestamp). Each cell has a `value` + property, which returns the deserialized field, and a `timestamp` property, which + returns the timestamp as `datetime` object. + If only a single `attributes._Attribute` was requested, the typing.List of cells will be + attached to the row dictionary directly (skipping the column dictionary). + """ + + # Create filters: Rows + row_set = RowSet() + if row_keys is not None: + row_set.row_keys = list(row_keys) + elif start_key is not None and end_key is not None: + row_set.add_row_range_from_keys( + start_key=start_key, + start_inclusive=True, + end_key=end_key, + end_inclusive=end_key_inclusive, + ) + else: + raise exceptions.PreconditionError( + "Need to either provide a valid set of rows, or" + " both, a start row and an end row." + ) + filter_ = utils.get_time_range_and_column_filter( + columns=columns, + start_time=start_time, + end_time=end_time, + end_inclusive=end_time_inclusive, + user_id=user_id, + ) + # Bigtable read with retries + rows = self._read(row_set=row_set, row_filter=filter_) + + # Deserialize cells + for row_key, column_dict in rows.items(): + for column, cell_entries in column_dict.items(): + for cell_entry in cell_entries: + cell_entry.value = column.deserialize(cell_entry.value) + # If no column array was requested, reattach single column's values directly to the row + if isinstance(columns, attributes._Attribute): + rows[row_key] = cell_entries + return rows + + def _read_byte_row( + self, + row_key: bytes, + columns: typing.Optional[ + typing.Union[typing.Iterable[attributes._Attribute], attributes._Attribute] + ] = None, + start_time: typing.Optional[datetime] = None, + end_time: typing.Optional[datetime] = None, + end_time_inclusive: bool = False, + ) -> typing.Union[ + typing.Dict[attributes._Attribute, typing.List[bigtable.row_data.Cell]], + typing.List[bigtable.row_data.Cell], + ]: + """Convenience function for reading a single row from Bigtable using its `bytes` keys. + + Arguments: + row_key {bytes} -- The row to be read. + + Keyword Arguments: + columns {typing.Optional[typing.Union[typing.Iterable[attributes._Attribute], attributes._Attribute]]} -- + typing.Optional filtering by columns to speed up the query. If `columns` is a single + column (not iterable), the column key will be omitted from the result. + (default: {None}) + start_time {typing.Optional[datetime]} -- Ignore cells with timestamp before + `start_time`. If None, no lower bound. (default: {None}) + end_time {typing.Optional[datetime]} -- Ignore cells with timestamp after `end_time`. + If None, no upper bound. (default: {None}) + end_time_inclusive {bool} -- Whether or not `end_time` itself should be included in the + request, ignored if `end_time` is None. (default: {False}) + + Returns: + typing.Union[typing.Dict[attributes._Attribute, typing.List[bigtable.row_data.Cell]], + typing.List[bigtable.row_data.Cell]] -- + Returns a mapping of columns to a typing.List of cells (one cell per timestamp). Each cell + has a `value` property, which returns the deserialized field, and a `timestamp` + property, which returns the timestamp as `datetime` object. 
+ If only a single `attributes._Attribute` was requested, the typing.List of cells is returned + directly. + """ + row = self._read_byte_rows( + row_keys=[row_key], + columns=columns, + start_time=start_time, + end_time=end_time, + end_time_inclusive=end_time_inclusive, + ) + return ( + row.get(row_key, []) + if isinstance(columns, attributes._Attribute) + else row.get(row_key, {}) + ) + + def _execute_read_thread(self, args: typing.Tuple[Table, RowSet, RowFilter]): + table, row_set, row_filter = args + if not row_set.row_keys and not row_set.row_ranges: + # Check for everything falsy, because Bigtable considers even empty + # lists of row_keys as no upper/lower bound! + return {} + range_read = table.read_rows(row_set=row_set, filter_=row_filter) + res = {v.row_key: utils.partial_row_data_to_column_dict(v) for v in range_read} + return res + + def _read( + self, row_set: RowSet, row_filter: RowFilter = None + ) -> typing.Dict[bytes, typing.Dict[attributes._Attribute, PartialRowData]]: + """Core function to read rows from Bigtable. Uses standard Bigtable retry logic + :param row_set: BigTable RowSet + :param row_filter: BigTable RowFilter + :return: typing.Dict[bytes, typing.Dict[attributes._Attribute, bigtable.row_data.PartialRowData]] + """ + # FIXME: Bigtable limits the length of the serialized request to 512 KiB. We should + # calculate this properly (range_read.request.SerializeToString()), but this estimate is + # good enough for now + + n_subrequests = max( + 1, int(np.ceil(len(row_set.row_keys) / self._max_row_key_count)) + ) + n_threads = min(n_subrequests, 2 * mu.n_cpus) + + row_sets = [] + for i in range(n_subrequests): + r = RowSet() + r.row_keys = row_set.row_keys[ + i * self._max_row_key_count : (i + 1) * self._max_row_key_count + ] + row_sets.append(r) + + # Don't forget the original RowSet's row_ranges + row_sets[0].row_ranges = row_set.row_ranges + responses = mu.multithread_func( + self._execute_read_thread, + params=((self._table, r, row_filter) for r in row_sets), + debug=n_threads == 1, + n_threads=n_threads, + ) + + combined_response = {} + for resp in responses: + combined_response.update(resp) + return combined_response diff --git a/pychunkedgraph/graph/client/bigtable/utils.py b/pychunkedgraph/graph/client/bigtable/utils.py new file mode 100644 index 000000000..2d30eeb32 --- /dev/null +++ b/pychunkedgraph/graph/client/bigtable/utils.py @@ -0,0 +1,304 @@ +from typing import Dict +from typing import Union +from typing import Iterable +from typing import Optional +from datetime import datetime +from datetime import timedelta + +import numpy as np +from google.cloud.bigtable.row_data import PartialRowData +from google.cloud.bigtable.row_filters import RowFilter +from google.cloud.bigtable.row_filters import PassAllFilter +from google.cloud.bigtable.row_filters import BlockAllFilter +from google.cloud.bigtable.row_filters import TimestampRange +from google.cloud.bigtable.row_filters import RowFilterChain +from google.cloud.bigtable.row_filters import RowFilterUnion +from google.cloud.bigtable.row_filters import ValueRangeFilter +from google.cloud.bigtable.row_filters import CellsRowLimitFilter +from google.cloud.bigtable.row_filters import ColumnRangeFilter +from google.cloud.bigtable.row_filters import TimestampRangeFilter +from google.cloud.bigtable.row_filters import ConditionalRowFilter +from google.cloud.bigtable.row_filters import ColumnQualifierRegexFilter + +from ... 
import attributes + + +def partial_row_data_to_column_dict( + partial_row_data: PartialRowData, +) -> Dict[attributes._Attribute, PartialRowData]: + new_column_dict = {} + for family_id, column_dict in partial_row_data._cells.items(): + for column_key, column_values in column_dict.items(): + column = attributes.from_key(family_id, column_key) + new_column_dict[column] = column_values + return new_column_dict + + +def get_google_compatible_time_stamp( + time_stamp: datetime, round_up: bool = False +) -> datetime: + """ + Makes a datetime time stamp compatible with googles' services. + Google restricts the accuracy of time stamps to milliseconds. Hence, the + microseconds are cut of. By default, time stamps are rounded to the lower + number. + """ + micro_s_gap = timedelta(microseconds=time_stamp.microsecond % 1000) + if micro_s_gap == 0: + return time_stamp + if round_up: + time_stamp += timedelta(microseconds=1000) - micro_s_gap + else: + time_stamp -= micro_s_gap + return time_stamp + + +def _get_column_filter( + columns: Union[Iterable[attributes._Attribute], attributes._Attribute] = None +) -> RowFilter: + """Generates a RowFilter that accepts the specified columns""" + if isinstance(columns, attributes._Attribute): + return ColumnRangeFilter( + columns.family_id, start_column=columns.key, end_column=columns.key + ) + elif len(columns) == 1: + return ColumnRangeFilter( + columns[0].family_id, start_column=columns[0].key, end_column=columns[0].key + ) + return RowFilterUnion( + [ + ColumnRangeFilter(col.family_id, start_column=col.key, end_column=col.key) + for col in columns + ] + ) + + +def _get_user_filter(user_id: str): + """generates a ColumnRegEx Filter which filters user ids + + Args: + user_id (str): userID to select for + """ + + condition = RowFilterChain( + [ + ColumnQualifierRegexFilter(attributes.OperationLogs.UserID.key), + ValueRangeFilter(str.encode(user_id), str.encode(user_id)), + CellsRowLimitFilter(1), + ] + ) + + conditional_filter = ConditionalRowFilter( + base_filter=condition, + true_filter=PassAllFilter(True), + false_filter=BlockAllFilter(True), + ) + return conditional_filter + + +def _get_time_range_filter( + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + end_inclusive: bool = True, +) -> RowFilter: + """Generates a TimeStampRangeFilter which is inclusive for start and (optionally) end. + + :param start: + :param end: + :return: + """ + # Comply to resolution of BigTables TimeRange + if start_time is not None: + start_time = get_google_compatible_time_stamp(start_time, round_up=False) + if end_time is not None: + end_time = get_google_compatible_time_stamp(end_time, round_up=end_inclusive) + return TimestampRangeFilter(TimestampRange(start=start_time, end=end_time)) + + +def get_time_range_and_column_filter( + columns: Optional[ + Union[Iterable[attributes._Attribute], attributes._Attribute] + ] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + end_inclusive: bool = False, + user_id: Optional[str] = None, +) -> RowFilter: + time_filter = _get_time_range_filter( + start_time=start_time, end_time=end_time, end_inclusive=end_inclusive + ) + filters = [time_filter] + if columns is not None: + if len(columns) == 0: + raise ValueError( + f"Empty column filter {columns} is ambiguous. Pass `None` if no column filter should be applied." 
+ ) + column_filter = _get_column_filter(columns) + filters = [column_filter, time_filter] + if user_id is not None: + user_filter = _get_user_filter(user_id=user_id) + filters.append(user_filter) + if len(filters) > 1: + return RowFilterChain(filters) + return filters[0] + + +def get_root_lock_filter( + lock_column, lock_expiry, indefinite_lock_column +) -> ConditionalRowFilter: + time_cutoff = datetime.utcnow() - lock_expiry + # Comply to resolution of BigTables TimeRange + time_cutoff -= timedelta(microseconds=time_cutoff.microsecond % 1000) + time_filter = TimestampRangeFilter(TimestampRange(start=time_cutoff)) + + # Build a column filter which tests if a lock was set (== lock column + # exists) and if it is still valid (timestamp younger than + # LOCK_EXPIRED_TIME_DELTA) and if there is no new parent (== new_parents + # exists) + lock_key_filter = ColumnRangeFilter( + column_family_id=lock_column.family_id, + start_column=lock_column.key, + end_column=lock_column.key, + inclusive_start=True, + inclusive_end=True, + ) + + indefinite_lock_key_filter = ColumnRangeFilter( + column_family_id=indefinite_lock_column.family_id, + start_column=indefinite_lock_column.key, + end_column=indefinite_lock_column.key, + inclusive_start=True, + inclusive_end=True, + ) + + new_parents_column = attributes.Hierarchy.NewParent + new_parents_key_filter = ColumnRangeFilter( + column_family_id=new_parents_column.family_id, + start_column=new_parents_column.key, + end_column=new_parents_column.key, + inclusive_start=True, + inclusive_end=True, + ) + + temporal_lock_filter = RowFilterChain([time_filter, lock_key_filter]) + return ConditionalRowFilter( + base_filter=RowFilterUnion([indefinite_lock_key_filter, temporal_lock_filter]), + true_filter=PassAllFilter(True), + false_filter=new_parents_key_filter, + ) + + +def get_indefinite_root_lock_filter(lock_column) -> ConditionalRowFilter: + lock_key_filter = ColumnRangeFilter( + column_family_id=lock_column.family_id, + start_column=lock_column.key, + end_column=lock_column.key, + inclusive_start=True, + inclusive_end=True, + ) + + new_parents_column = attributes.Hierarchy.NewParent + new_parents_key_filter = ColumnRangeFilter( + column_family_id=new_parents_column.family_id, + start_column=new_parents_column.key, + end_column=new_parents_column.key, + inclusive_start=True, + inclusive_end=True, + ) + + return ConditionalRowFilter( + base_filter=lock_key_filter, + true_filter=PassAllFilter(True), + false_filter=new_parents_key_filter, + ) + + +def get_renew_lock_filter( + lock_column: attributes._Attribute, operation_id: np.uint64 +) -> ConditionalRowFilter: + new_parents_column = attributes.Hierarchy.NewParent + operation_id_b = lock_column.serialize(operation_id) + + # Build a column filter which tests if a lock was set (== lock column + # exists) and if the given operation_id is still the active lock holder + # and there is no new parent (== new_parents column exists). The latter + # is not necessary but we include it as a backup to prevent things + # from going really bad. 
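+ # The chain below matches rows whose lock column still holds `operation_id`;
+ # the conditional filter then surfaces only the NewParent column for such
+ # rows, which the caller can use to decide whether the lock may be renewed.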
+ + column_key_filter = ColumnRangeFilter( + column_family_id=lock_column.family_id, + start_column=lock_column.key, + end_column=lock_column.key, + inclusive_start=True, + inclusive_end=True, + ) + + value_filter = ValueRangeFilter( + start_value=operation_id_b, + end_value=operation_id_b, + inclusive_start=True, + inclusive_end=True, + ) + + new_parents_key_filter = ColumnRangeFilter( + column_family_id=new_parents_column.family_id, + start_column=new_parents_column.key, + end_column=new_parents_column.key, + inclusive_start=True, + inclusive_end=True, + ) + + return ConditionalRowFilter( + base_filter=RowFilterChain([column_key_filter, value_filter]), + true_filter=new_parents_key_filter, + false_filter=PassAllFilter(True), + ) + + +def get_unlock_root_filter(lock_column, lock_expiry, operation_id) -> RowFilterChain: + time_cutoff = datetime.utcnow() - lock_expiry + # Comply to resolution of BigTables TimeRange + time_cutoff -= timedelta(microseconds=time_cutoff.microsecond % 1000) + time_filter = TimestampRangeFilter(TimestampRange(start=time_cutoff)) + + # Build a column filter which tests if a lock was set (== lock column + # exists) and if it is still valid (timestamp younger than + # LOCK_EXPIRED_TIME_DELTA) and if the given operation_id is still + # the active lock holder + column_key_filter = ColumnRangeFilter( + column_family_id=lock_column.family_id, + start_column=lock_column.key, + end_column=lock_column.key, + inclusive_start=True, + inclusive_end=True, + ) + + value_filter = ValueRangeFilter( + start_value=lock_column.serialize(operation_id), + end_value=lock_column.serialize(operation_id), + inclusive_start=True, + inclusive_end=True, + ) + + # Chain these filters together + return RowFilterChain([time_filter, column_key_filter, value_filter]) + + +def get_indefinite_unlock_root_filter(lock_column, operation_id) -> RowFilterChain: + column_key_filter = ColumnRangeFilter( + column_family_id=lock_column.family_id, + start_column=lock_column.key, + end_column=lock_column.key, + inclusive_start=True, + inclusive_end=True, + ) + + value_filter = ValueRangeFilter( + start_value=lock_column.serialize(operation_id), + end_value=lock_column.serialize(operation_id), + inclusive_start=True, + inclusive_end=True, + ) + + # Chain these filters together + return RowFilterChain([column_key_filter, value_filter]) diff --git a/pychunkedgraph/graph/client/utils.py b/pychunkedgraph/graph/client/utils.py new file mode 100644 index 000000000..12eebec82 --- /dev/null +++ b/pychunkedgraph/graph/client/utils.py @@ -0,0 +1,3 @@ +""" +Common client util functions +""" \ No newline at end of file diff --git a/pychunkedgraph/examples/parallel_test/__init__.py b/pychunkedgraph/graph/connectivity/__init__.py similarity index 100% rename from pychunkedgraph/examples/parallel_test/__init__.py rename to pychunkedgraph/graph/connectivity/__init__.py diff --git a/pychunkedgraph/graph/connectivity/cross_edges.py b/pychunkedgraph/graph/connectivity/cross_edges.py new file mode 100644 index 000000000..8aa52a9f1 --- /dev/null +++ b/pychunkedgraph/graph/connectivity/cross_edges.py @@ -0,0 +1,219 @@ +import time +import math +import multiprocessing as mp +from collections import defaultdict +from typing import Optional +from typing import Sequence +from typing import List +from typing import Dict + +import numpy as np +from multiwrapper.multiprocessing_utils import multiprocess_func + +from .. 
import attributes +from ..types import empty_2d +from ..utils import basetypes +from ..utils import serializers +from ..chunkedgraph import ChunkedGraph +from ..utils.generic import get_valid_timestamp +from ..utils.generic import filter_failed_node_ids +from ..chunks.atomic import get_touching_atomic_chunks +from ..chunks.atomic import get_bounding_atomic_chunks +from ...utils.general import chunked + + +def get_children_chunk_cross_edges( + cg, layer, chunk_coord, *, use_threads=True +) -> np.ndarray: + """ + Cross edges that connect children chunks. + The edges are between node IDs in the given layer (not atomic). + """ + atomic_chunks = get_touching_atomic_chunks(cg.meta, layer, chunk_coord) + if not len(atomic_chunks): + return [] + + print(f"touching atomic chunk count {len(atomic_chunks)}") + if not use_threads: + return _get_children_chunk_cross_edges(cg, atomic_chunks, layer - 1) + + print("get_children_chunk_cross_edges, atomic chunks", len(atomic_chunks)) + with mp.Manager() as manager: + edge_ids_shared = manager.list() + edge_ids_shared.append(empty_2d) + + task_size = int(math.ceil(len(atomic_chunks) / mp.cpu_count() / 10)) + chunked_l2chunk_list = chunked(atomic_chunks, task_size) + multi_args = [] + for atomic_chunks in chunked_l2chunk_list: + multi_args.append( + (edge_ids_shared, cg.get_serialized_info(), atomic_chunks, layer - 1) + ) + + multiprocess_func( + _get_children_chunk_cross_edges_helper, + multi_args, + n_threads=min(len(multi_args), mp.cpu_count()), + ) + + cross_edges = np.concatenate(edge_ids_shared) + if cross_edges.size: + return np.unique(cross_edges, axis=0) + return cross_edges + + +def _get_children_chunk_cross_edges_helper(args) -> None: + edge_ids_shared, cg_info, atomic_chunks, layer = args + cg = ChunkedGraph(**cg_info) + edge_ids_shared.append(_get_children_chunk_cross_edges(cg, atomic_chunks, layer)) + + +def _get_children_chunk_cross_edges(cg, atomic_chunks, layer) -> None: + print( + f"_get_children_chunk_cross_edges {layer} atomic_chunks count {len(atomic_chunks)}" + ) + cross_edges = [empty_2d] + for layer2_chunk in atomic_chunks: + edges = _read_atomic_chunk_cross_edges(cg, layer2_chunk, layer) + cross_edges.append(edges) + + cross_edges = np.concatenate(cross_edges) + if not cross_edges.size: + return empty_2d + print(f"getting roots at stop_layer {layer} {cross_edges.shape}") + cross_edges[:, 0] = cg.get_roots(cross_edges[:, 0], stop_layer=layer, ceil=False) + cross_edges[:, 1] = cg.get_roots(cross_edges[:, 1], stop_layer=layer, ceil=False) + result = np.unique(cross_edges, axis=0) if cross_edges.size else empty_2d + print(f"_get_children_chunk_cross_edges done {result.shape}") + return result + + +def _read_atomic_chunk_cross_edges( + cg, chunk_coord: Sequence[int], cross_edge_layer: int +) -> np.ndarray: + cross_edge_col = attributes.Connectivity.CrossChunkEdge[cross_edge_layer] + range_read, l2ids = _read_atomic_chunk(cg, chunk_coord, [cross_edge_layer]) + + parent_neighboring_chunk_supervoxels_d = defaultdict(list) + for l2id in l2ids: + if not cross_edge_col in range_read[l2id]: + continue + edges = range_read[l2id][cross_edge_col][0].value + parent_neighboring_chunk_supervoxels_d[l2id] = edges[:, 1] + + cross_edges = [empty_2d] + for l2id in parent_neighboring_chunk_supervoxels_d: + nebor_svs = parent_neighboring_chunk_supervoxels_d[l2id] + chunk_parent_ids = np.array([l2id] * len(nebor_svs), dtype=basetypes.NODE_ID) + cross_edges.append(np.vstack([chunk_parent_ids, nebor_svs]).T) + cross_edges = np.concatenate(cross_edges) + return 
cross_edges + + +def get_chunk_nodes_cross_edge_layer( + cg, layer: int, chunk_coord: Sequence[int], use_threads=True +) -> Dict: + """ + gets nodes in a chunk that are part of cross chunk edges + return_type dict {node_id: layer} + the lowest layer (>= current layer) at which a node_id is part of a cross edge + """ + print("get_bounding_atomic_chunks") + atomic_chunks = get_bounding_atomic_chunks(cg.meta, layer, chunk_coord) + print("get_bounding_atomic_chunks complete") + if not len(atomic_chunks): + return {} + + if not use_threads: + return _get_chunk_nodes_cross_edge_layer(cg, atomic_chunks, layer) + + print("divide tasks") + cg_info = cg.get_serialized_info() + manager = mp.Manager() + ids_l_shared = manager.list() + layers_l_shared = manager.list() + task_size = int(math.ceil(len(atomic_chunks) / mp.cpu_count() / 10)) + chunked_l2chunk_list = chunked(atomic_chunks, task_size) + multi_args = [] + for atomic_chunks in chunked_l2chunk_list: + multi_args.append( + (ids_l_shared, layers_l_shared, cg_info, atomic_chunks, layer) + ) + print("divide tasks complete") + + multiprocess_func( + _get_chunk_nodes_cross_edge_layer_helper, + multi_args, + n_threads=min(len(multi_args), mp.cpu_count()), + ) + + node_layer_d_shared = manager.dict() + _find_min_layer(node_layer_d_shared, ids_l_shared, layers_l_shared) + print("_find_min_layer complete") + return node_layer_d_shared + + +def _get_chunk_nodes_cross_edge_layer_helper(args): + ids_l_shared, layers_l_shared, cg_info, atomic_chunks, layer = args + cg = ChunkedGraph(**cg_info) + node_layer_d = _get_chunk_nodes_cross_edge_layer(cg, atomic_chunks, layer) + ids_l_shared.append(np.fromiter(node_layer_d.keys(), dtype=basetypes.NODE_ID)) + layers_l_shared.append(np.fromiter(node_layer_d.values(), dtype=np.uint8)) + + +def _get_chunk_nodes_cross_edge_layer(cg, atomic_chunks, layer): + atomic_node_layer_d = {} + for atomic_chunk in atomic_chunks: + chunk_node_layer_d = _read_atomic_chunk_cross_edge_nodes( + cg, atomic_chunk, range(layer, cg.meta.layer_count + 1) + ) + atomic_node_layer_d.update(chunk_node_layer_d) + + l2ids = np.fromiter(atomic_node_layer_d.keys(), dtype=basetypes.NODE_ID) + parents = cg.get_roots(l2ids, stop_layer=layer - 1, ceil=False) + layers = np.fromiter(atomic_node_layer_d.values(), dtype=int) + + node_layer_d = defaultdict(lambda: cg.meta.layer_count) + for i, parent in enumerate(parents): + node_layer_d[parent] = min(node_layer_d[parent], layers[i]) + return node_layer_d + + +def _read_atomic_chunk_cross_edge_nodes(cg, chunk_coord, cross_edge_layers): + node_layer_d = {} + range_read, l2ids = _read_atomic_chunk(cg, chunk_coord, cross_edge_layers) + for l2id in l2ids: + for layer in cross_edge_layers: + if attributes.Connectivity.CrossChunkEdge[layer] in range_read[l2id]: + node_layer_d[l2id] = layer + break + return node_layer_d + + +def _find_min_layer(node_layer_d_shared, ids_l_shared, layers_l_shared): + node_ids = np.concatenate(ids_l_shared) + layers = np.concatenate(layers_l_shared) + for i, node_id in enumerate(node_ids): + layer = node_layer_d_shared.get(node_id, layers[i]) + node_layer_d_shared[node_id] = min(layer, layers[i]) + + +def _read_atomic_chunk(cg, chunk_coord, layers): + x, y, z = chunk_coord + child_col = attributes.Hierarchy.Child + range_read = cg.range_read_chunk( + cg.get_chunk_id(layer=2, x=x, y=y, z=z), + properties=[child_col] + + [attributes.Connectivity.CrossChunkEdge[l] for l in layers], + ) + + row_ids = [] + max_children_ids = [] + for row_id, row_data in range_read.items(): + 
row_ids.append(row_id) + max_children_ids.append(np.max(row_data[child_col][0].value)) + + row_ids = np.array(row_ids, dtype=basetypes.NODE_ID) + segment_ids = np.array([cg.get_segment_id(r_id) for r_id in row_ids]) + l2ids = filter_failed_node_ids(row_ids, segment_ids, max_children_ids) + return range_read, l2ids diff --git a/pychunkedgraph/graph/connectivity/nodes.py b/pychunkedgraph/graph/connectivity/nodes.py new file mode 100644 index 000000000..3d1f57c6c --- /dev/null +++ b/pychunkedgraph/graph/connectivity/nodes.py @@ -0,0 +1,29 @@ +from typing import Iterable +from itertools import combinations + +from ..types import Agglomeration +from ...utils.general import reverse_dictionary + + +def edge_exists(agglomerations: Iterable[Agglomeration]): + """ + Determine if there is an edge (in-active) + between atleast two of the given nodes (L2 agglomerations). + """ + supervoxel_parent_d = {} + for agg in agglomerations: + supervoxel_parent_d.update( + zip(agg.supervoxels, [agg.node_id] * len(agg.supervoxels)) + ) + + for agg_1, agg_2 in combinations(agglomerations, 2): + targets1 = agg_1.out_edges[:, 1] + targets2 = agg_2.out_edges[:, 1] + + for t1, t2 in zip(targets1, targets2): + if ( + supervoxel_parent_d[t1] == agg_2.node_id + or supervoxel_parent_d[t2] == agg_1.node_id + ): + return True + return False diff --git a/pychunkedgraph/graph/connectivity/search.py b/pychunkedgraph/graph/connectivity/search.py new file mode 100644 index 000000000..bd3faf227 --- /dev/null +++ b/pychunkedgraph/graph/connectivity/search.py @@ -0,0 +1,47 @@ +import random +from typing import List + +import numpy as np +from graph_tool.search import bfs_search +from graph_tool.search import BFSVisitor +from graph_tool.search import StopSearch + +from ..utils.basetypes import NODE_ID + + +class TargetVisitor(BFSVisitor): + def __init__(self, target, reachable): + self.target = target + self.reachable = reachable + + def discover_vertex(self, u): + if u == self.target: + self.reachable[u] = 1 + raise StopSearch + + +def check_reachability(g, sv1s: np.ndarray, sv2s: np.ndarray, original_ids: np.ndarray) -> np.ndarray: + """ + g: graph tool Graph instance with ids 0 to N-1 where N = vertex count + original_ids: sorted ChunkedGraph supervoxel ids + (to identify corresponding ids in graph tool) + for each pair (sv1, sv2) check if a path exists (BFS) + """ + # mapping from original ids to graph tool ids + original_ids_d = { + sv_id: index for sv_id, index in zip(original_ids, range(len(original_ids))) + } + reachable = g.new_vertex_property("int", val=0) + + def _check_reachability(source, target): + bfs_search(g, source, TargetVisitor(target, reachable)) + return reachable[target] + + return np.array( + [ + _check_reachability(original_ids_d[source], original_ids_d[target]) + for source, target in zip(sv1s, sv2s) + ], + dtype=bool, + ) + diff --git a/pychunkedgraph/admin/bigtable_admin.py b/pychunkedgraph/graph/connectivity/utils.py similarity index 100% rename from pychunkedgraph/admin/bigtable_admin.py rename to pychunkedgraph/graph/connectivity/utils.py diff --git a/pychunkedgraph/graph/cutting.py b/pychunkedgraph/graph/cutting.py new file mode 100644 index 000000000..8b1583871 --- /dev/null +++ b/pychunkedgraph/graph/cutting.py @@ -0,0 +1,682 @@ +import collections +import fastremap +import numpy as np +import itertools +import logging +import time +import graph_tool +import graph_tool.flow + +from typing import Dict +from typing import Tuple +from typing import Optional +from typing import Sequence +from typing 
import Iterable + +from .utils import flatgraph +from .utils import basetypes +from .utils.generic import get_bounding_box +from .edges import Edges +from .exceptions import PreconditionError +from .exceptions import PostconditionError + +DEBUG_MODE = False + + +class IsolatingCutException(Exception): + """Raised when mincut would split off one of the labeled supervoxel exactly. + This is used to trigger a PostconditionError with a custom message. + """ + + pass + + +def merge_cross_chunk_edges_graph_tool( + edges: Iterable[Sequence[np.uint64]], affs: Sequence[np.uint64] +): + """Merges cross chunk edges + :param edges: n x 2 array of uint64s + :param affs: float array of length n + :return: + """ + # mask for edges that have to be merged + cross_chunk_edge_mask = np.isinf(affs) + # graph with edges that have to be merged + graph, _, _, unique_supervoxel_ids = flatgraph.build_gt_graph( + edges[cross_chunk_edge_mask], make_directed=True + ) + + # connected components in this graph will be combined in one component + ccs = flatgraph.connected_components(graph) + remapping = {} + mapping = [] + + for cc in ccs: + nodes = unique_supervoxel_ids[cc] + rep_node = np.min(nodes) + remapping[rep_node] = nodes + rep_nodes = np.ones(len(nodes), dtype=np.uint64).reshape(-1, 1) * rep_node + m = np.concatenate([nodes.reshape(-1, 1), rep_nodes], axis=1) + mapping.append(m) + + if len(mapping) > 0: + mapping = np.concatenate(mapping) + u_nodes = np.unique(edges) + u_unmapped_nodes = u_nodes[~np.in1d(u_nodes, mapping)] + unmapped_mapping = np.concatenate( + [u_unmapped_nodes.reshape(-1, 1), u_unmapped_nodes.reshape(-1, 1)], axis=1 + ) + if len(mapping) > 0: + complete_mapping = np.concatenate([mapping, unmapped_mapping], axis=0) + else: + complete_mapping = unmapped_mapping + + sort_idx = np.argsort(complete_mapping[:, 0]) + idx = np.searchsorted(complete_mapping[:, 0], edges, sorter=sort_idx) + mapped_edges = np.asarray(complete_mapping[:, 1])[sort_idx][idx] + mapped_edges = mapped_edges[~cross_chunk_edge_mask] + mapped_affs = affs[~cross_chunk_edge_mask] + return mapped_edges, mapped_affs, mapping, complete_mapping, remapping + + +class LocalMincutGraph: + """ + Helper class for mincut computation. Used by the mincut_graph_tool function to: + (1) set up a local graph-tool graph, (2) compute a mincut, (3) ensure required conditions hold, + and (4) return the ChunkedGraph edges to be removed. + """ + + def __init__( + self, + cg_edges, + cg_affs, + cg_sources, + cg_sinks, + split_preview=False, + path_augment=True, + disallow_isolating_cut=True, + logger=None, + ): + self.cg_edges = cg_edges + self.split_preview = split_preview + self.logger = logger + self.path_augment = path_augment + self.disallow_isolating_cut = disallow_isolating_cut + + time_start = time.time() + + # Stitch supervoxels across chunk boundaries and represent those that are + # connected with a cross chunk edge with a single id. This may cause id + # changes among sinks and sources that need to be taken care of. 
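+ # `complete_mapping` maps every supervoxel id in the local graph to its
+ # representative id (identity for supervoxels without cross chunk edges),
+ # and `cross_chunk_edge_remapping` maps each representative back to the
+ # full set of supervoxel ids it stands for.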
+ ( + mapped_edges, + mapped_affs, + cross_chunk_edge_mapping, + complete_mapping, + self.cross_chunk_edge_remapping, + ) = merge_cross_chunk_edges_graph_tool(cg_edges, cg_affs) + + dt = time.time() - time_start + if logger is not None: + logger.debug("Cross edge merging: %.2fms" % (dt * 1000)) + time_start = time.time() + + if len(mapped_edges) == 0: + raise PostconditionError( + f"Local graph somehow only contains cross chunk edges" + ) + + if len(cross_chunk_edge_mapping) > 0: + assert ( + np.unique(cross_chunk_edge_mapping[:, 0], return_counts=True)[1].max() + == 1 + ) + + # Map cg sources and sinks with the cross chunk edge mapping + self.sources = fastremap.remap_from_array_kv( + np.array(cg_sources), complete_mapping[:, 0], complete_mapping[:, 1] + ) + self.sinks = fastremap.remap_from_array_kv( + np.array(cg_sinks), complete_mapping[:, 0], complete_mapping[:, 1] + ) + + self._build_gt_graph(mapped_edges, mapped_affs) + + self.source_path_vertices = self.source_graph_ids + self.sink_path_vertices = self.sink_graph_ids + + dt = time.time() - time_start + if logger is not None: + logger.debug("Graph creation: %.2fms" % (dt * 1000)) + + self._create_fake_edge_property(mapped_affs) + + def _build_gt_graph(self, edges, affs): + """ + Create the graphs that will be used to compute the mincut. + """ + + # Assemble graph without infinite-affinity edges + ( + self.weighted_graph_raw, + self.capacities_raw, + self.gt_edges_raw, + _, + ) = flatgraph.build_gt_graph(edges, affs, make_directed=True) + + self.source_edges = list(itertools.product(self.sources, self.sources)) + self.sink_edges = list(itertools.product(self.sinks, self.sinks)) + + # Assemble edges: Edges after remapping combined with fake infinite affinity + # edges between sinks and sources + comb_edges = np.concatenate([edges, self.source_edges, self.sink_edges]) + comb_affs = np.concatenate( + [ + affs, + [np.finfo(np.float32).max] + * (len(self.source_edges) + len(self.sink_edges)), + ] + ) + + # To make things easier for everyone involved, we map the ids to + # [0, ..., len(unique_supervoxel_ids) - 1] + # Generate weighted graph with graph_tool + ( + self.weighted_graph, + self.capacities, + self.gt_edges, + self.unique_supervoxel_ids, + ) = flatgraph.build_gt_graph(comb_edges, comb_affs, make_directed=True) + + self.source_graph_ids = np.where( + np.in1d(self.unique_supervoxel_ids, self.sources) + )[0] + self.sink_graph_ids = np.where(np.in1d(self.unique_supervoxel_ids, self.sinks))[ + 0 + ] + + if self.logger is not None: + self.logger.debug(f"{self.sinks}, {self.sink_graph_ids}") + self.logger.debug(f"{self.sources}, {self.source_graph_ids}") + + def _compute_mincut_direct(self): + """Uses additional edges directly between source/sink points.""" + self._filter_graph_connected_components() + src, tgt = ( + self.weighted_graph.vertex(self.source_graph_ids[0]), + self.weighted_graph.vertex(self.sink_graph_ids[0]), + ) + + residuals = graph_tool.flow.push_relabel_max_flow( + self.weighted_graph, src, tgt, self.capacities + ) + partition = graph_tool.flow.min_st_cut( + self.weighted_graph, src, self.capacities, residuals + ) + return partition + + def _augment_mincut_capacity(self): + """Increase affinities along all pairs shortest paths between sources/sinks + in the supervoxel graph. 
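+ Returns the adjusted capacity map, which is then used for the max-flow
+ computation in _compute_mincut_path_augmented.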
+ """ + try: + paths_v_s, paths_e_s, invaff_s = flatgraph.compute_filtered_paths( + self.weighted_graph_raw, + self.capacities_raw, + self.source_graph_ids, + self.sink_graph_ids, + ) + paths_v_y, paths_e_y, invaff_y = flatgraph.compute_filtered_paths( + self.weighted_graph_raw, + self.capacities_raw, + self.sink_graph_ids, + self.source_graph_ids, + ) + except AssertionError: + raise PreconditionError( + "Paths between source or sink points irreparably overlap other labels from other side. " + "Check that labels are correct and consider spreading points out farther." + ) + + paths_e_s_no, paths_e_y_no, do_check = flatgraph.remove_overlapping_edges( + paths_v_s, paths_e_s, paths_v_y, paths_e_y + ) + if do_check: + e_connected = flatgraph.check_connectedness(paths_v_s, paths_e_s_no) + y_connected = flatgraph.check_connectedness(paths_v_y, paths_e_y_no) + if e_connected is False or y_connected is False: + try: + paths_e_s_no, paths_e_y_no = self.rerun_paths_without_overlap( + paths_v_s, + paths_e_s, + invaff_s, + paths_v_y, + paths_e_y, + invaff_y, + ) + except AssertionError: + raise PreconditionError( + "Paths between source point pairs and sink point pairs overlapped irreparably. " + "Consider doing cut in multiple parts." + ) + + self.source_path_vertices = flatgraph.flatten_edge_list(paths_e_s_no) + self.sink_path_vertices = flatgraph.flatten_edge_list(paths_e_y_no) + + adj_capacity = flatgraph.adjust_affinities( + self.weighted_graph_raw, self.capacities_raw, paths_e_s_no + paths_e_y_no + ) + return adj_capacity + + def rerun_paths_without_overlap( + self, + paths_v_s, + paths_e_s, + invaff_s, + paths_v_y, + paths_e_y, + invaff_y, + invert_winner=False, + ): + + # smaller distance means larger affinity + s_wins = flatgraph.harmonic_mean_paths( + invaff_s + ) < flatgraph.harmonic_mean_paths(invaff_y) + if invert_winner: + s_wins = not s_wins + + # Omit winning team vertices from graph + try: + if s_wins: + paths_e_s_no = paths_e_s + omit_verts = [int(v) for v in itertools.chain.from_iterable(paths_v_s)] + _, paths_e_y_no, _ = flatgraph.compute_filtered_paths( + self.weighted_graph_raw, + self.capacities_raw, + self.sink_graph_ids, + omit_verts, + ) + + else: + omit_verts = [int(v) for v in itertools.chain.from_iterable(paths_v_y)] + _, paths_e_s_no, _ = flatgraph.compute_filtered_paths( + self.weighted_graph_raw, + self.capacities_raw, + self.source_graph_ids, + omit_verts, + ) + paths_e_y_no = paths_e_y + except AssertionError: + # If no path is found and this hasn't been tried before, try giving the overlap to the other team and finding paths + if not invert_winner: + paths_e_s_no, paths_e_y_no = self.rerun_paths_without_overlap( + paths_v_s, + paths_e_s, + invaff_s, + paths_v_y, + paths_e_y, + invaff_y, + invert_winner=True, + ) + else: + # Otherwise propagate the AssertionError back up + raise AssertionError + return paths_e_s_no, paths_e_y_no + + def _compute_mincut_path_augmented(self): + """Compute mincut using edges found from a shortest-path search.""" + adj_capacity = self._augment_mincut_capacity() + + gr = self.weighted_graph_raw + src, tgt = gr.vertex(self.source_graph_ids[0]), gr.vertex( + self.sink_graph_ids[0] + ) + + residuals = graph_tool.flow.boykov_kolmogorov_max_flow( + gr, src, tgt, adj_capacity + ) + + partition = graph_tool.flow.min_st_cut(gr, src, adj_capacity, residuals) + return partition + + def compute_mincut(self): + """ + Compute mincut and return the supervoxel cut edge set + """ + + time_start = time.time() + + if self.path_augment: + partition = 
self._compute_mincut_path_augmented() + else: + partition = self._compute_mincut_direct() + + dt = time.time() - time_start + if self.logger is not None: + self.logger.debug("Mincut comp: %.2fms" % (dt * 1000)) + + if DEBUG_MODE: + self._gt_mincut_sanity_check(partition) + + if self.path_augment: + labeled_edges = partition.a[self.gt_edges_raw] + cut_edge_set = self.gt_edges_raw[labeled_edges[:, 0] != labeled_edges[:, 1]] + else: + labeled_edges = partition.a[self.gt_edges] + cut_edge_set = self.gt_edges[labeled_edges[:, 0] != labeled_edges[:, 1]] + if self.split_preview: + return self._get_split_preview_connected_components(cut_edge_set) + + self._sink_and_source_connectivity_sanity_check(cut_edge_set) + return self._remap_cut_edge_set(cut_edge_set) + + def _remap_cut_edge_set(self, cut_edge_set): + """ + Remap the cut edge set from graph ids to supervoxel ids and return it + """ + remapped_cutset = [] + for s, t in flatgraph.remap_ids_from_graph( + cut_edge_set, self.unique_supervoxel_ids + ): + + if s in self.cross_chunk_edge_remapping: + s = self.cross_chunk_edge_remapping[s] + else: + s = [s] + + if t in self.cross_chunk_edge_remapping: + t = self.cross_chunk_edge_remapping[t] + else: + t = [t] + + remapped_cutset.extend(list(itertools.product(s, t))) + remapped_cutset.extend(list(itertools.product(t, s))) + + remapped_cutset = np.array(remapped_cutset, dtype=np.uint64) + + remapped_cutset_flattened_view = remapped_cutset.view(dtype="u8,u8") + edges_flattened_view = self.cg_edges.view(dtype="u8,u8") + + cutset_mask = np.in1d(remapped_cutset_flattened_view, edges_flattened_view) + + return remapped_cutset[cutset_mask] + + def _remap_graph_ids_to_cg_supervoxels(self, graph_ids): + supervoxel_list = [] + # Supervoxels that were passed into graph + mapped_supervoxels = self.unique_supervoxel_ids[graph_ids] + # Now need to remap these using the cross_chunk_edge_remapping + for sv in mapped_supervoxels: + if sv in self.cross_chunk_edge_remapping: + supervoxel_list.extend(self.cross_chunk_edge_remapping[sv]) + else: + supervoxel_list.append(sv) + return np.array(supervoxel_list) + + def _get_split_preview_connected_components(self, cut_edge_set): + """ + Return the connected components of the local graph (in terms of supervoxels) + when doing a split preview + """ + ( + ccs_test_post_cut, + illegal_split, + ) = self._sink_and_source_connectivity_sanity_check(cut_edge_set) + supervoxel_ccs = [None] * len(ccs_test_post_cut) + # Return a list of connected components where the first component always contains + # the most sources and the second always contains the most sinks (to make life easier for Neuroglancer) + max_source_index = -1 + max_sources = 0 + max_sink_index = -1 + max_sinks = 0 + i = 0 + for cc in ccs_test_post_cut: + num_sources = np.count_nonzero(np.in1d(self.source_graph_ids, cc)) + num_sinks = np.count_nonzero(np.in1d(self.sink_graph_ids, cc)) + if num_sources > max_sources: + max_sources = num_sources + max_source_index = i + if num_sinks > max_sinks: + max_sinks = num_sinks + max_sink_index = i + i += 1 + supervoxel_ccs[0] = self._remap_graph_ids_to_cg_supervoxels( + ccs_test_post_cut[max_source_index] + ) + supervoxel_ccs[1] = self._remap_graph_ids_to_cg_supervoxels( + ccs_test_post_cut[max_sink_index] + ) + i = 0 + j = 2 + for cc in ccs_test_post_cut: + if i != max_source_index and i != max_sink_index: + supervoxel_ccs[j] = self._remap_graph_ids_to_cg_supervoxels(cc) + j += 1 + i += 1 + return (supervoxel_ccs, illegal_split) + + def _create_fake_edge_property(self, affs): + 
""" + Create an edge property to remove fake edges later + (will be used to test whether split valid) + """ + is_fake_edge = np.concatenate( + [ + [False] * len(affs), + [True] * (len(self.source_edges) + len(self.sink_edges)), + ] + ) + remove_edges_later = np.concatenate([is_fake_edge, is_fake_edge]) + self.edges_to_remove = self.weighted_graph.new_edge_property( + "bool", vals=remove_edges_later + ) + + def _filter_graph_connected_components(self): + """ + Filter out connected components in the graph + that are not involved in the local mincut + """ + ccs = flatgraph.connected_components(self.weighted_graph) + + removed = self.weighted_graph.new_vertex_property("bool") + removed.a = False + if len(ccs) > 1: + for cc in ccs: + # If connected component contains no sources or no sinks, + # remove its nodes from the mincut computation + if not ( + np.any(np.in1d(self.source_graph_ids, cc)) + and np.any(np.in1d(self.sink_graph_ids, cc)) + ): + for node_id in cc: + removed[node_id] = True + + self.weighted_graph.set_vertex_filter(removed, inverted=True) + pruned_graph = graph_tool.Graph(self.weighted_graph, prune=True) + # Test that there is only one connected component left + ccs = flatgraph.connected_components(pruned_graph) + if len(ccs) > 1: + if self.logger is not None: + self.logger.warning( + "Not all sinks and sources are within the same (local)" + "connected component" + ) + raise PreconditionError( + "Not all sinks and sources are within the same (local)" + "connected component" + ) + elif len(ccs) == 0: + raise PreconditionError( + "Sinks and sources are not connected through the local graph. " + "Please try a different set of vertices to perform the mincut." + ) + + def _gt_mincut_sanity_check(self, partition): + """ + After the mincut has been computed, assert that: the sources are within + one connected component, and the sinks are within another separate one. + These assertions should not fail. If they do, + then something went wrong with the graph_tool mincut computation + """ + for i_cc in np.unique(partition.a): + # Make sure to read real ids and not graph ids + cc_list = self.unique_supervoxel_ids[ + np.array(np.where(partition.a == i_cc)[0], dtype=int) + ] + + if np.any(np.in1d(self.sources, cc_list)): + assert np.all(np.in1d(self.sources, cc_list)) + assert ~np.any(np.in1d(self.sinks, cc_list)) + + if np.any(np.in1d(self.sinks, cc_list)): + assert np.all(np.in1d(self.sinks, cc_list)) + assert ~np.any(np.in1d(self.sources, cc_list)) + + def _sink_and_source_connectivity_sanity_check(self, cut_edge_set): + """ + Similar to _gt_mincut_sanity_check, except we do the check again *after* + removing the fake infinite affinity edges. 
+ """ + time_start = time.time() + for cut_edge in cut_edge_set: + # May be more than one edge from vertex cut_edge[0] to vertex cut_edge[1], remove them all + parallel_edges = self.weighted_graph.edge( + cut_edge[0], cut_edge[1], all_edges=True + ) + for edge_to_remove in parallel_edges: + self.edges_to_remove[edge_to_remove] = True + + self.weighted_graph.set_edge_filter(self.edges_to_remove, True) + ccs_test_post_cut = flatgraph.connected_components(self.weighted_graph) + + # Make sure sinks and sources are among each other and not in different sets + # after removing the cut edges and the fake infinity edges + illegal_split = False + try: + for cc in ccs_test_post_cut: + if np.any(np.in1d(self.source_graph_ids, cc)): + assert np.all(np.in1d(self.source_graph_ids, cc)) + assert ~np.any(np.in1d(self.sink_graph_ids, cc)) + if ( + len(self.source_path_vertices) == len(cc) + and self.disallow_isolating_cut + ): + if not self.partition_edges_within_label(cc): + raise IsolatingCutException("Source") + + if np.any(np.in1d(self.sink_graph_ids, cc)): + assert np.all(np.in1d(self.sink_graph_ids, cc)) + assert ~np.any(np.in1d(self.source_graph_ids, cc)) + if ( + len(self.sink_path_vertices) == len(cc) + and self.disallow_isolating_cut + ): + if not self.partition_edges_within_label(cc): + raise IsolatingCutException("Sink") + + except AssertionError: + if self.split_preview: + # If we are performing a split preview, we allow these illegal splits, + # but return a flag to return a message to the user + illegal_split = True + else: + raise PreconditionError( + "Failed to find a cut that separated the sources from the sinks. " + "Please try another cut that partitions the sets cleanly if possible. " + "If there is a clear path between all the supervoxels in each set, " + "that helps the mincut algorithm." + ) + except IsolatingCutException as e: + if self.split_preview: + illegal_split = True + else: + raise PreconditionError( + f"Split cut off only the labeled points on the {e} side. Please additional points to other parts of the merge error on that side." + ) + + dt = time.time() - time_start + if self.logger is not None: + self.logger.debug("Verifying local graph: %.2fms" % (dt * 1000)) + return ccs_test_post_cut, illegal_split + + def partition_edges_within_label(self, cc): + """Test is an isolated component has out-edges only within the original + labeled points of the cut + """ + label_graph_ids = np.concatenate((self.source_graph_ids, self.sink_graph_ids)) + + for vind in cc: + v = self.weighted_graph_raw.vertex(vind) + out_vinds = [int(x) for x in v.out_neighbors()] + if not np.all(np.isin(out_vinds, label_graph_ids)): + return False + else: + return True + + +def run_multicut( + edges: Edges, + source_ids: Sequence[np.uint64], + sink_ids: Sequence[np.uint64], + *, + split_preview: bool = False, + path_augment: bool = True, + disallow_isolating_cut: bool = True, +): + local_mincut_graph = LocalMincutGraph( + edges.get_pairs(), + edges.affinities, + source_ids, + sink_ids, + split_preview, + path_augment, + disallow_isolating_cut=disallow_isolating_cut, + ) + atomic_edges = local_mincut_graph.compute_mincut() + if len(atomic_edges) == 0: + raise PostconditionError(f"Mincut failed. 
Try with a different set of points.") + return atomic_edges + + +def run_split_preview( + cg, + source_ids: Sequence[np.uint64], + sink_ids: Sequence[np.uint64], + source_coords: Sequence[Sequence[int]], + sink_coords: Sequence[Sequence[int]], + bb_offset: Tuple[int, int, int] = (120, 120, 12), + path_augment: bool = True, + disallow_isolating_cut: bool = True, +): + root_ids = set( + cg.get_roots(np.concatenate([source_ids, sink_ids]), assert_roots=True) + ) + if len(root_ids) > 1: + raise PreconditionError("Supervoxels must belong to the same object.") + + bbox = get_bounding_box(source_coords, sink_coords, bb_offset) + l2id_agglomeration_d, edges = cg.get_subgraph( + root_ids.pop(), bbox=bbox, bbox_is_coordinate=True + ) + in_edges, out_edges, cross_edges = edges + edges = in_edges + out_edges + cross_edges + supervoxels = np.concatenate( + [agg.supervoxels for agg in l2id_agglomeration_d.values()] + ) + mask0 = np.in1d(edges.node_ids1, supervoxels) + mask1 = np.in1d(edges.node_ids2, supervoxels) + edges = edges[mask0 & mask1] + edges_to_remove, illegal_split = run_multicut( + edges, + source_ids, + sink_ids, + split_preview=True, + path_augment=path_augment, + disallow_isolating_cut=disallow_isolating_cut, + ) + + if len(edges_to_remove) == 0: + raise PostconditionError("Mincut could not find any edges to remove.") + + return edges_to_remove, illegal_split diff --git a/pychunkedgraph/graph/edges/__init__.py b/pychunkedgraph/graph/edges/__init__.py new file mode 100644 index 000000000..b0e488d05 --- /dev/null +++ b/pychunkedgraph/graph/edges/__init__.py @@ -0,0 +1,105 @@ +""" +Classes and types for edges +""" + +from typing import Optional +from collections import namedtuple + +import numpy as np + +from ..utils import basetypes + + +_edge_type_fileds = ("in_chunk", "between_chunk", "cross_chunk") +_edge_type_defaults = ("in", "between", "cross") + +EdgeTypes = namedtuple("EdgeTypes", _edge_type_fileds, defaults=_edge_type_defaults) +EDGE_TYPES = EdgeTypes() + +DEFAULT_AFFINITY = np.finfo(np.float32).tiny +DEFAULT_AREA = np.finfo(np.float32).tiny + + +class Edges: + def __init__( + self, + node_ids1: np.ndarray, + node_ids2: np.ndarray, + *, + affinities: Optional[np.ndarray] = None, + areas: Optional[np.ndarray] = None, + fake_edges=False, + ): + self.node_ids1 = np.array(node_ids1, dtype=basetypes.NODE_ID, copy=False) + self.node_ids2 = np.array(node_ids2, dtype=basetypes.NODE_ID, copy=False) + assert self.node_ids1.size == self.node_ids2.size + + self._as_pairs = None + self._fake_edges = fake_edges + + if affinities is not None and len(affinities) > 0: + self._affinities = np.array(affinities, dtype=basetypes.EDGE_AFFINITY, copy=False) + assert self.node_ids1.size == self._affinities.size + else: + self._affinities = np.full(len(self.node_ids1), DEFAULT_AFFINITY) + + if areas is not None and len(areas) > 0: + self._areas = np.array(areas, dtype=basetypes.EDGE_AREA, copy=False) + assert self.node_ids1.size == self._areas.size + else: + self._areas = np.full(len(self.node_ids1), DEFAULT_AREA) + + @property + def affinities(self) -> np.ndarray: + return self._affinities + + @affinities.setter + def affinities(self, affinities): + self._affinities = affinities + + @property + def areas(self) -> np.ndarray: + return self._areas + + @areas.setter + def areas(self, areas): + self._areas = areas + + def __add__(self, other): + """add two Edges instances""" + node_ids1 = np.concatenate([self.node_ids1, other.node_ids1]) + node_ids2 = np.concatenate([self.node_ids2, other.node_ids2]) + 
affinities = np.concatenate([self.affinities, other.affinities]) + areas = np.concatenate([self.areas, other.areas]) + return Edges(node_ids1, node_ids2, affinities=affinities, areas=areas) + + def __iadd__(self, other): + self.node_ids1 = np.concatenate([self.node_ids1, other.node_ids1]) + self.node_ids2 = np.concatenate([self.node_ids2, other.node_ids2]) + self.affinities = np.concatenate([self.affinities, other.affinities]) + self.areas = np.concatenate([self.areas, other.areas]) + return self + + def __len__(self): + return self.node_ids1.size + + def __getitem__(self, key): + """`key` must be a boolean numpy array.""" + try: + return Edges( + self.node_ids1[key], + self.node_ids2[key], + affinities=self.affinities[key], + areas=self.areas[key], + ) + except Exception as err: + raise (err) + + def get_pairs(self) -> np.ndarray: + """ + return numpy array of edge pairs [[sv1, sv2] ... ] + """ + if not self._as_pairs is None: + return self._as_pairs + self._as_pairs = np.column_stack((self.node_ids1, self.node_ids2)) + return self._as_pairs diff --git a/pychunkedgraph/graph/edges/utils.py b/pychunkedgraph/graph/edges/utils.py new file mode 100644 index 000000000..034ca6ebc --- /dev/null +++ b/pychunkedgraph/graph/edges/utils.py @@ -0,0 +1,216 @@ +# pylint: disable=invalid-name, missing-docstring, c-extension-no-member + +""" +helper functions for edge stuff +""" + +from typing import Dict +from typing import Tuple +from typing import Iterable +from typing import Optional + +import fastremap +import numpy as np + +from . import Edges +from . import EDGE_TYPES +from ..types import empty_2d +from ..utils import basetypes +from ..chunks import utils as chunk_utils +from ..meta import ChunkedGraphMeta + + +def concatenate_chunk_edges(chunk_edge_dicts: Iterable) -> Dict: + """combine edge_dicts of multiple chunks into one edge_dict""" + edges_dict = {} + for edge_type in EDGE_TYPES: + sv_ids1 = [np.array([], dtype=basetypes.NODE_ID)] + sv_ids2 = [np.array([], dtype=basetypes.NODE_ID)] + affinities = [np.array([], dtype=basetypes.EDGE_AFFINITY)] + areas = [np.array([], dtype=basetypes.EDGE_AREA)] + for edge_d in chunk_edge_dicts: + edges = edge_d[edge_type] + sv_ids1.append(edges.node_ids1) + sv_ids2.append(edges.node_ids2) + affinities.append(edges.affinities) + areas.append(edges.areas) + + sv_ids1 = np.concatenate(sv_ids1) + sv_ids2 = np.concatenate(sv_ids2) + affinities = np.concatenate(affinities) + areas = np.concatenate(areas) + edges_dict[edge_type] = Edges( + sv_ids1, sv_ids2, affinities=affinities, areas=areas + ) + return edges_dict + + +def concatenate_cross_edge_dicts(edges_ds: Iterable[Dict]) -> Dict: + """Combines cross chunk edge dicts of form {layer id : edge list}.""" + from collections import defaultdict + + result_d = defaultdict(list) + + for edges_d in edges_ds: + for layer, edges in edges_d.items(): + result_d[layer].append(edges) + + for layer, edge_lists in result_d.items(): + result_d[layer] = np.concatenate(edge_lists) + return result_d + + +def merge_cross_edge_dicts(x_edges_d1: Dict, x_edges_d2: Dict) -> Dict: + """ + Combines two cross chunk dictionaries of form + {node_id: {layer id : edge list}}. 
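+ Node ids present in only one of the two dicts are kept as-is; node ids
+ present in both have their per-layer edge lists concatenated.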
+ """ + node_ids = np.unique(list(x_edges_d1.keys()) + list(x_edges_d2.keys())) + result_d = {} + for node_id in node_ids: + cross_edge_ds = [x_edges_d1.get(node_id, {}), x_edges_d2.get(node_id, {})] + result_d[node_id] = concatenate_cross_edge_dicts(cross_edge_ds) + return result_d + + +def categorize_edges( + meta: ChunkedGraphMeta, supervoxels: np.ndarray, edges: Edges +) -> Tuple[Edges, Edges, Edges]: + """ + Find edges and categorize them into: + `in_edges` + between given supervoxels + (sv1, sv2) - sv1 in supervoxels and sv2 in supervoxels + `out_edges` + originating from given supervoxels but within chunk + (sv1, sv2) - sv1 in supervoxels and sv2 not in supervoxels + `cross_edges` + originating from given supervoxels but crossing chunk boundary + """ + mask1 = np.isin(edges.node_ids1, supervoxels) + mask2 = np.isin(edges.node_ids2, supervoxels) + in_mask = mask1 & mask2 + out_mask = mask1 & ~mask2 + + in_edges = edges[in_mask] + all_out_edges = edges[out_mask] # out_edges + cross_edges + + edge_layers = get_cross_chunk_edges_layer(meta, all_out_edges.get_pairs()) + cross_edges_mask = edge_layers > 1 + out_edges = all_out_edges[~cross_edges_mask] + cross_edges = all_out_edges[cross_edges_mask] + return (in_edges, out_edges, cross_edges) + + +def categorize_edges_v2( + meta: ChunkedGraphMeta, + edges: Edges, + sv_parent_d: Dict, +) -> Tuple[Edges, Edges, Edges]: + """Faster version of categorize_edges(), avoids looping over L2 IDs.""" + + node_ids1 = fastremap.remap( + edges.node_ids1, sv_parent_d, preserve_missing_labels=True + ) + node_ids2 = fastremap.remap( + edges.node_ids2, sv_parent_d, preserve_missing_labels=True + ) + + layer_mask1 = chunk_utils.get_chunk_layers(meta, node_ids1) > 1 + nodes_mask = node_ids1 == node_ids2 + + in_edges = edges[nodes_mask] + all_out_ = edges[layer_mask1 & ~nodes_mask] + + cx_layers = get_cross_chunk_edges_layer(meta, all_out_.get_pairs()) + + cx_mask = cx_layers > 1 + out_edges = all_out_[~cx_mask] + cross_edges = all_out_[cx_mask] + return (in_edges, out_edges, cross_edges) + + +def get_cross_chunk_edges_layer(meta: ChunkedGraphMeta, cross_edges: Iterable): + """Computes the layer in which a cross chunk edge becomes relevant. + I.e. if a cross chunk edge links two nodes in layer 4 this function + returns 3. + :param cross_edges: n x 2 array + edges between atomic (level 1) node ids + :return: array of length n + """ + if len(cross_edges) == 0: + return np.array([], dtype=int) + cross_chunk_edge_layers = np.ones(len(cross_edges), dtype=int) + coords0 = chunk_utils.get_chunk_coordinates_multiple(meta, cross_edges[:, 0]) + coords1 = chunk_utils.get_chunk_coordinates_multiple(meta, cross_edges[:, 1]) + + for _ in range(2, meta.layer_count): + edge_diff = np.sum(np.abs(coords0 - coords1), axis=1) + cross_chunk_edge_layers[edge_diff > 0] += 1 + coords0 = coords0 // meta.graph_config.FANOUT + coords1 = coords1 // meta.graph_config.FANOUT + return cross_chunk_edge_layers + + +def filter_min_layer_cross_edges( + meta: ChunkedGraphMeta, cross_edges_d: Dict, node_layer: int = 2 +) -> Tuple[int, Iterable]: + """ + Given a dict of cross chunk edges {layer: edges} + Return the first layer with cross edges. 
+ """ + for layer in range(node_layer, meta.layer_count): + edges_ = cross_edges_d.get(layer, empty_2d) + if edges_.size: + return (layer, edges_) + return (meta.layer_count, edges_) + + +def filter_min_layer_cross_edges_multiple( + meta: ChunkedGraphMeta, l2id_atomic_cross_edges_ds: Iterable, node_layer: int = 2 +) -> Tuple[int, Iterable]: + """ + Given a list of dicts of cross chunk edges [{layer: edges}] + Return the first layer with cross edges. + """ + min_layer = meta.layer_count + for edges_d in l2id_atomic_cross_edges_ds: + layer_, _ = filter_min_layer_cross_edges(meta, edges_d, node_layer=node_layer) + min_layer = min(min_layer, layer_) + edges = [empty_2d] + for edges_d in l2id_atomic_cross_edges_ds: + edges.append(edges_d.get(min_layer, empty_2d)) + return min_layer, np.concatenate(edges) + + +def get_edges_status(cg, edges: Iterable, time_stamp: Optional[float] = None): + from ...utils.general import in2d + + coords0 = chunk_utils.get_chunk_coordinates_multiple(cg.meta, edges[:, 0]) + coords1 = chunk_utils.get_chunk_coordinates_multiple(cg.meta, edges[:, 1]) + + coords = np.concatenate([np.array(coords0), np.array(coords1)]) + bbox = [np.min(coords, axis=0), np.max(coords, axis=0)] + bbox[1] += 1 + + root_ids = set( + cg.get_roots(edges.ravel(), assert_roots=True, time_stamp=time_stamp) + ) + sg_edges = cg.get_subgraph( + root_ids, + bbox=bbox, + bbox_is_coordinate=False, + edges_only=True, + ) + existence_status = in2d(edges, sg_edges) + edge_layers = cg.get_cross_chunk_edges_layer(edges) + active_status = [] + for layer in np.unique(edge_layers): + layer_edges = edges[edge_layers == layer] + edges_parents = cg.get_roots( + layer_edges.ravel(), time_stamp=time_stamp, stop_layer=layer + 1 + ).reshape(-1, 2) + mask = edges_parents[:, 0] == edges_parents[:, 1] + active_status.extend(mask) + active_status = np.array(active_status, dtype=bool) + return existence_status, active_status diff --git a/pychunkedgraph/graph/edits.py b/pychunkedgraph/graph/edits.py new file mode 100644 index 000000000..be2eee1c6 --- /dev/null +++ b/pychunkedgraph/graph/edits.py @@ -0,0 +1,599 @@ +# pylint: disable=invalid-name, missing-docstring, too-many-locals, c-extension-no-member + +import datetime +from typing import Dict +from typing import List +from typing import Tuple +from typing import Iterable +from collections import defaultdict + +import numpy as np +import fastremap + +from . import types +from . import attributes +from . import cache as cache_utils +from .edges.utils import concatenate_cross_edge_dicts +from .edges.utils import merge_cross_edge_dicts +from .utils import basetypes +from .utils import flatgraph +from .utils.serializers import serialize_uint64 +from ..logging.log_db import TimeIt +from ..utils.general import in2d + + +def _init_old_hierarchy(cg, l2ids: np.ndarray, parent_ts: datetime.datetime = None): + new_old_id_d = defaultdict(set) + old_new_id_d = defaultdict(set) + old_hierarchy_d = {id_: {2: id_} for id_ in l2ids} + for id_ in l2ids: + layer_parent_d = cg.get_all_parents_dict(id_, time_stamp=parent_ts) + old_hierarchy_d[id_].update(layer_parent_d) + for parent in layer_parent_d.values(): + old_hierarchy_d[parent] = old_hierarchy_d[id_] + return new_old_id_d, old_new_id_d, old_hierarchy_d + + +def _analyze_affected_edges( + cg, atomic_edges: Iterable[np.ndarray], parent_ts: datetime.datetime = None +) -> Tuple[Iterable, Dict]: + """ + Determine if atomic edges are within the chunk. + If not, they are cross edges between two L2 IDs in adjacent chunks. 
+ Returns edges between L2 IDs and atomic cross edges. + """ + supervoxels = np.unique(atomic_edges) + parents = cg.get_parents(supervoxels, time_stamp=parent_ts) + sv_parent_d = dict(zip(supervoxels.tolist(), parents)) + edge_layers = cg.get_cross_chunk_edges_layer(atomic_edges) + parent_edges = [ + [sv_parent_d[edge_[0]], sv_parent_d[edge_[1]]] + for edge_ in atomic_edges[edge_layers == 1] + ] + + # cross chunk edges + atomic_cross_edges_d = defaultdict(lambda: defaultdict(list)) + for layer in range(2, cg.meta.layer_count): + layer_edges = atomic_edges[edge_layers == layer] + if not layer_edges.size: + continue + for edge in layer_edges: + parent_1 = sv_parent_d[edge[0]] + parent_2 = sv_parent_d[edge[1]] + atomic_cross_edges_d[parent_1][layer].append(edge) + atomic_cross_edges_d[parent_2][layer].append(edge[::-1]) + parent_edges.extend([[parent_1, parent_1], [parent_2, parent_2]]) + return (parent_edges, atomic_cross_edges_d) + + +def _get_relevant_components(edges: np.ndarray, supervoxels: np.ndarray) -> Tuple: + edges = np.concatenate([edges, np.vstack([supervoxels, supervoxels]).T]) + graph, _, _, graph_ids = flatgraph.build_gt_graph(edges, make_directed=True) + ccs = flatgraph.connected_components(graph) + relevant_ccs = [] + # remove if connected component contains no sources or no sinks + # when merging, there must be only two components + for cc_idx in ccs: + cc = graph_ids[cc_idx] + if np.any(np.in1d(supervoxels, cc)): + relevant_ccs.append(cc) + assert len(relevant_ccs) == 2, "must be 2 components" + return relevant_ccs + + +def merge_preprocess( + cg, + *, + subgraph_edges: np.ndarray, + supervoxels: np.ndarray, + parent_ts: datetime.datetime = None, +) -> np.ndarray: + """ + Determine if a fake edge needs to be added. + Get subgraph within the bounding box + Add fake edge if there are no inactive edges between two components. 
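+ Returns the unique inactive edges found between the two relevant components;
+ an empty result means a fake edge has to be added later (see check_fake_edges).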
+ """ + edge_layers = cg.get_cross_chunk_edges_layer(subgraph_edges) + active_edges = [types.empty_2d] + inactive_edges = [types.empty_2d] + for layer in np.unique(edge_layers): + _edges = subgraph_edges[edge_layers == layer] + edge_nodes = fastremap.unique(_edges.ravel()) + roots = cg.get_roots(edge_nodes, time_stamp=parent_ts, stop_layer=layer + 1) + parent_map = dict(zip(edge_nodes, roots)) + parent_edges = fastremap.remap(_edges, parent_map, preserve_missing_labels=True) + + active_mask = parent_edges[:, 0] == parent_edges[:, 1] + active, inactive = _edges[active_mask], _edges[~active_mask] + active_edges.append(active) + inactive_edges.append(inactive) + + relevant_ccs = _get_relevant_components(np.concatenate(active_edges), supervoxels) + inactive = np.concatenate(inactive_edges) + _inactive = [types.empty_2d] + # source to sink edges + source_mask = np.in1d(inactive[:, 0], relevant_ccs[0]) + sink_mask = np.in1d(inactive[:, 1], relevant_ccs[1]) + _inactive.append(inactive[source_mask & sink_mask]) + + # sink to source edges + sink_mask = np.in1d(inactive[:, 1], relevant_ccs[0]) + source_mask = np.in1d(inactive[:, 0], relevant_ccs[1]) + _inactive.append(inactive[source_mask & sink_mask]) + _inactive = np.concatenate(_inactive) + return np.unique(_inactive, axis=0) if _inactive.size else types.empty_2d + + +def check_fake_edges( + cg, + *, + atomic_edges: Iterable[np.ndarray], + inactive_edges: Iterable[np.ndarray], + time_stamp: datetime.datetime, + parent_ts: datetime.datetime = None, +) -> Tuple[Iterable[np.ndarray], Iterable]: + """if no inactive edges found, add user input as fake edge.""" + if inactive_edges.size: + roots = np.unique( + cg.get_roots( + np.unique(inactive_edges), + assert_roots=True, + time_stamp=parent_ts, + ) + ) + assert len(roots) == 2, "edges must be from 2 roots" + print("found inactive", len(inactive_edges)) + return inactive_edges, [] + + rows = [] + supervoxels = atomic_edges.ravel() + chunk_ids = cg.get_chunk_ids_from_node_ids( + cg.get_parents(supervoxels, time_stamp=parent_ts) + ) + sv_l2chunk_id_d = dict(zip(supervoxels.tolist(), chunk_ids)) + for edge in atomic_edges: + id1, id2 = sv_l2chunk_id_d[edge[0]], sv_l2chunk_id_d[edge[1]] + val_dict = {} + val_dict[attributes.Connectivity.FakeEdges] = np.array( + [[edge]], dtype=basetypes.NODE_ID + ) + id1 = serialize_uint64(id1, fake_edges=True) + rows.append( + cg.client.mutate_row( + id1, + val_dict, + time_stamp=time_stamp, + ) + ) + val_dict = {} + val_dict[attributes.Connectivity.FakeEdges] = np.array( + [edge[::-1]], dtype=basetypes.NODE_ID + ) + id2 = serialize_uint64(id2, fake_edges=True) + rows.append( + cg.client.mutate_row( + id2, + val_dict, + time_stamp=time_stamp, + ) + ) + print("no inactive", len(atomic_edges)) + return atomic_edges, rows + + +def add_edges( + cg, + *, + atomic_edges: Iterable[np.ndarray], + operation_id: np.uint64 = None, + time_stamp: datetime.datetime = None, + parent_ts: datetime.datetime = None, + allow_same_segment_merge=False, +): + edges, l2_atomic_cross_edges_d = _analyze_affected_edges( + cg, atomic_edges, parent_ts=parent_ts + ) + l2ids = np.unique(edges) + if not allow_same_segment_merge: + assert ( + np.unique(cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts)).size + == 2 + ), "L2 IDs must belong to different roots." 
+ new_old_id_d, old_new_id_d, old_hierarchy_d = _init_old_hierarchy( + cg, l2ids, parent_ts=parent_ts + ) + atomic_children_d = cg.get_children(l2ids) + atomic_cross_edges_d = merge_cross_edge_dicts( + cg.get_atomic_cross_edges(l2ids), l2_atomic_cross_edges_d + ) + + graph, _, _, graph_ids = flatgraph.build_gt_graph(edges, make_directed=True) + components = flatgraph.connected_components(graph) + new_l2_ids = [] + for cc_indices in components: + l2ids_ = graph_ids[cc_indices] + new_id = cg.id_client.create_node_id(cg.get_chunk_id(l2ids_[0])) + cg.cache.children_cache[new_id] = np.concatenate( + [atomic_children_d[l2id] for l2id in l2ids_] + ) + cg.cache.atomic_cx_edges_cache[new_id] = concatenate_cross_edge_dicts( + [atomic_cross_edges_d[l2id] for l2id in l2ids_] + ) + cache_utils.update( + cg.cache.parents_cache, cg.cache.children_cache[new_id], new_id + ) + new_l2_ids.append(new_id) + new_old_id_d[new_id].update(l2ids_) + for id_ in l2ids_: + old_new_id_d[id_].add(new_id) + + create_parents = CreateParentNodes( + cg, + new_l2_ids=new_l2_ids, + old_hierarchy_d=old_hierarchy_d, + new_old_id_d=new_old_id_d, + old_new_id_d=old_new_id_d, + operation_id=operation_id, + time_stamp=time_stamp, + parent_ts=parent_ts, + ) + + new_roots = create_parents.run() + new_entries = create_parents.create_new_entries() + return new_roots, new_l2_ids, new_entries + + +def _process_l2_agglomeration( + agg: types.Agglomeration, + removed_edges: np.ndarray, + atomic_cross_edges_d: Dict[int, np.ndarray], +): + """ + For a given L2 id, remove given edges + and calculate new connected components. + """ + chunk_edges = agg.in_edges.get_pairs() + cross_edges = np.concatenate([types.empty_2d, *atomic_cross_edges_d.values()]) + chunk_edges = chunk_edges[~in2d(chunk_edges, removed_edges)] + cross_edges = cross_edges[~in2d(cross_edges, removed_edges)] + + isolated_ids = agg.supervoxels[~np.in1d(agg.supervoxels, chunk_edges)] + isolated_edges = np.column_stack((isolated_ids, isolated_ids)) + graph, _, _, graph_ids = flatgraph.build_gt_graph( + np.concatenate([chunk_edges, isolated_edges]), make_directed=True + ) + return flatgraph.connected_components(graph), graph_ids, cross_edges + + +def _filter_component_cross_edges( + cc_ids: np.ndarray, cross_edges: np.ndarray, cross_edge_layers: np.ndarray +) -> Dict[int, np.ndarray]: + """ + Filters cross edges for a connected component `cc_ids` + from `cross_edges` of the complete chunk. + """ + mask = np.in1d(cross_edges[:, 0], cc_ids) + cross_edges_ = cross_edges[mask] + cross_edge_layers_ = cross_edge_layers[mask] + edges_d = {} + for layer in np.unique(cross_edge_layers_): + edge_m = cross_edge_layers_ == layer + _cross_edges = cross_edges_[edge_m] + if _cross_edges.size: + edges_d[layer] = _cross_edges + return edges_d + + +def remove_edges( + cg, + *, + atomic_edges: Iterable[np.ndarray], + l2id_agglomeration_d: Dict, + operation_id: basetypes.OPERATION_ID = None, + time_stamp: datetime.datetime = None, + parent_ts: datetime.datetime = None, +): + edges, _ = _analyze_affected_edges(cg, atomic_edges, parent_ts=parent_ts) + l2ids = np.unique(edges) + assert ( + np.unique(cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts)).size + == 1 + ), "L2 IDs must belong to same root." 
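+    # For each affected L2 ID, drop the removed edges (in both directions) from
+    # its agglomeration and recompute connected components; each component gets
+    # a new L2 ID with cross edges filtered to its supervoxels.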
+ new_old_id_d, old_new_id_d, old_hierarchy_d = _init_old_hierarchy( + cg, l2ids, parent_ts=parent_ts + ) + l2id_chunk_id_d = dict(zip(l2ids.tolist(), cg.get_chunk_ids_from_node_ids(l2ids))) + atomic_cross_edges_d = cg.get_atomic_cross_edges(l2ids) + + removed_edges = np.concatenate([atomic_edges, atomic_edges[:, ::-1]], axis=0) + new_l2_ids = [] + for id_ in l2ids: + l2_agg = l2id_agglomeration_d[id_] + ccs, graph_ids, cross_edges = _process_l2_agglomeration( + l2_agg, removed_edges, atomic_cross_edges_d[id_] + ) + # calculated here to avoid repeat computation in loop + cross_edge_layers = cg.get_cross_chunk_edges_layer(cross_edges) + new_parent_ids = cg.id_client.create_node_ids( + l2id_chunk_id_d[l2_agg.node_id], len(ccs) + ) + for i_cc, cc in enumerate(ccs): + new_id = new_parent_ids[i_cc] + cg.cache.children_cache[new_id] = graph_ids[cc] + cg.cache.atomic_cx_edges_cache[new_id] = _filter_component_cross_edges( + graph_ids[cc], cross_edges, cross_edge_layers + ) + cache_utils.update(cg.cache.parents_cache, graph_ids[cc], new_id) + new_l2_ids.append(new_id) + new_old_id_d[new_id].add(id_) + old_new_id_d[id_].add(new_id) + + create_parents = CreateParentNodes( + cg, + new_l2_ids=new_l2_ids, + old_hierarchy_d=old_hierarchy_d, + new_old_id_d=new_old_id_d, + old_new_id_d=old_new_id_d, + operation_id=operation_id, + time_stamp=time_stamp, + parent_ts=parent_ts, + ) + new_roots = create_parents.run() + new_entries = create_parents.create_new_entries() + return new_roots, new_l2_ids, new_entries + + +class CreateParentNodes: + def __init__( + self, + cg, + *, + new_l2_ids: Iterable, + operation_id: basetypes.OPERATION_ID, + time_stamp: datetime.datetime, + new_old_id_d: Dict[np.uint64, Iterable[np.uint64]] = None, + old_new_id_d: Dict[np.uint64, Iterable[np.uint64]] = None, + old_hierarchy_d: Dict[np.uint64, Dict[int, np.uint64]] = None, + parent_ts: datetime.datetime = None, + ): + self.cg = cg + self._new_l2_ids = new_l2_ids + self._old_hierarchy_d = old_hierarchy_d + self._new_old_id_d = new_old_id_d + self._old_new_id_d = old_new_id_d + self._new_ids_d = defaultdict(list) # new IDs in each layer + self._cross_edges_d = {} + self._operation_id = operation_id + self._time_stamp = time_stamp + self._last_successful_ts = parent_ts + + def _update_id_lineage( + self, + parent: basetypes.NODE_ID, + children: np.ndarray, + layer: int, + parent_layer: int, + ): + mask = np.in1d(children, self._new_ids_d[layer]) + for child_id in children[mask]: + child_old_ids = self._new_old_id_d[child_id] + for id_ in child_old_ids: + old_id = self._old_hierarchy_d[id_].get(parent_layer, id_) + self._new_old_id_d[parent].add(old_id) + self._old_new_id_d[old_id].add(parent) + + def _get_old_ids(self, new_ids): + old_ids = [ + np.array(list(self._new_old_id_d[id_]), dtype=basetypes.NODE_ID) + for id_ in new_ids + ] + return np.concatenate(old_ids) + + def _map_sv_to_parent(self, node_ids, layer, node_map=None): + sv_parent_d = {} + sv_cross_edges = [types.empty_2d] + if node_map is None: + node_map = {} + for id_ in node_ids: + id_eff = node_map.get(id_, id_) + edges_ = self._cross_edges_d[id_].get(layer, types.empty_2d) + sv_parent_d.update(dict(zip(edges_[:, 0], [id_eff] * len(edges_)))) + sv_cross_edges.append(edges_) + return sv_parent_d, np.concatenate(sv_cross_edges) + + def _get_connected_components( + self, node_ids: np.ndarray, layer: int, lower_layer_ids: np.ndarray + ): + _node_ids = np.concatenate([node_ids, lower_layer_ids]) + cached = np.fromiter(self._cross_edges_d.keys(), dtype=basetypes.NODE_ID) + 
not_cached = _node_ids[~np.in1d(_node_ids, cached)] + + with TimeIt( + f"get_cross_chunk_edges.{layer}", + self.cg.graph_id, + self._operation_id, + ): + self._cross_edges_d.update( + self.cg.get_cross_chunk_edges(not_cached, all_layers=True) + ) + + sv_parent_d, sv_cross_edges = self._map_sv_to_parent(node_ids, layer) + get_sv_parents = np.vectorize(sv_parent_d.get, otypes=[np.uint64]) + try: + cross_edges = get_sv_parents(sv_cross_edges) + except TypeError: # NoneType error + # if there is a missing parent, try including lower layer ids + # this can happen due to skip connections + + # we want to map all these lower IDs to the current layer + lower_layer_to_layer = self.cg.get_roots( + lower_layer_ids, stop_layer=layer, ceil=False + ) + node_map = {k: v for k, v in zip(lower_layer_ids, lower_layer_to_layer)} + sv_parent_d, sv_cross_edges = self._map_sv_to_parent( + _node_ids, layer, node_map=node_map + ) + get_sv_parents = np.vectorize(sv_parent_d.get, otypes=[np.uint64]) + cross_edges = get_sv_parents(sv_cross_edges) + + cross_edges = np.concatenate([cross_edges, np.vstack([node_ids, node_ids]).T]) + graph, _, _, graph_ids = flatgraph.build_gt_graph( + cross_edges, make_directed=True + ) + return flatgraph.connected_components(graph), graph_ids + + def _get_layer_node_ids( + self, new_ids: np.ndarray, layer: int + ) -> Tuple[np.ndarray, np.ndarray]: + # get old identities of new IDs + old_ids = self._get_old_ids(new_ids) + # get their parents, then children of those parents + node_ids = self.cg.get_children( + np.unique( + self.cg.get_parents(old_ids, time_stamp=self._last_successful_ts) + ), + flatten=True, + ) + # replace old identities with new IDs + mask = np.in1d(node_ids, old_ids) + node_ids = np.concatenate( + [ + np.array(list(self._old_new_id_d[id_]), dtype=basetypes.NODE_ID) + for id_ in node_ids[mask] + ] + + [node_ids[~mask], new_ids] + ) + node_ids = np.unique(node_ids) + layer_mask = self.cg.get_chunk_layers(node_ids) == layer + return node_ids[layer_mask], node_ids[~layer_mask] + + def _create_new_parents(self, layer: int): + """ + keep track of old IDs + merge - one new ID has 2 old IDs + split - two/more new IDs have the same old ID + get parents of old IDs, their children are the siblings + those siblings include old IDs, replace with new + get cross edges of all, find connected components + update parent old IDs + """ + new_ids = self._new_ids_d[layer] + layer_node_ids, lower_layer_ids = self._get_layer_node_ids(new_ids, layer) + components, graph_ids = self._get_connected_components( + layer_node_ids, layer, lower_layer_ids + ) + for cc_indices in components: + parent_layer = layer + 1 + cc_ids = graph_ids[cc_indices] + if len(cc_ids) == 1: + # skip connection + parent_layer = self.cg.meta.layer_count + for l in range(layer + 1, self.cg.meta.layer_count): + if len(self._cross_edges_d[cc_ids[0]].get(l, types.empty_2d)) > 0: + parent_layer = l + break + + parent_id = self.cg.id_client.create_node_id( + self.cg.get_parent_chunk_id(cc_ids[0], parent_layer), + root_chunk=parent_layer == self.cg.meta.layer_count, + ) + self._new_ids_d[parent_layer].append(parent_id) + self.cg.cache.children_cache[parent_id] = cc_ids + cache_utils.update( + self.cg.cache.parents_cache, + cc_ids, + parent_id, + ) + self._update_id_lineage(parent_id, cc_ids, layer, parent_layer) + + def run(self) -> Iterable: + """ + After new level 2 IDs are created, create parents in higher layers. + Cross edges are used to determine existing siblings. 
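+        Proceeds bottom-up, one layer at a time; a component with no cross edges
+        until a higher layer is connected directly to that layer (skip connection).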
+ """ + self._new_ids_d[2] = self._new_l2_ids + for layer in range(2, self.cg.meta.layer_count): + if len(self._new_ids_d[layer]) == 0: + continue + with TimeIt( + f"create_new_parents_layer.{layer}", + self.cg.graph_id, + self._operation_id, + ): + self._create_new_parents(layer) + return self._new_ids_d[self.cg.meta.layer_count] + + def _update_root_id_lineage(self): + new_root_ids = self._new_ids_d[self.cg.meta.layer_count] + former_root_ids = self._get_old_ids(new_root_ids) + former_root_ids = np.unique(former_root_ids) + assert ( + len(former_root_ids) < 2 or len(new_root_ids) < 2 + ), "Something went wrong." + rows = [] + for new_root_id in new_root_ids: + val_dict = { + attributes.Hierarchy.FormerParent: np.array(former_root_ids), + attributes.OperationLogs.OperationID: self._operation_id, + } + rows.append( + self.cg.client.mutate_row( + serialize_uint64(new_root_id), + val_dict, + time_stamp=self._time_stamp, + ) + ) + + for former_root_id in former_root_ids: + val_dict = { + attributes.Hierarchy.NewParent: np.array(new_root_ids), + attributes.OperationLogs.OperationID: self._operation_id, + } + rows.append( + self.cg.client.mutate_row( + serialize_uint64(former_root_id), + val_dict, + time_stamp=self._time_stamp, + ) + ) + return rows + + def _get_atomic_cross_edges_val_dict(self): + new_ids = np.array(self._new_ids_d[2], dtype=basetypes.NODE_ID) + val_dicts = {} + atomic_cross_edges_d = self.cg.get_atomic_cross_edges(new_ids) + for id_ in new_ids: + val_dict = {} + for layer, edges in atomic_cross_edges_d[id_].items(): + val_dict[attributes.Connectivity.CrossChunkEdge[layer]] = edges + val_dicts[id_] = val_dict + return val_dicts + + def create_new_entries(self) -> List: + rows = [] + val_dicts = self._get_atomic_cross_edges_val_dict() + for layer in range(2, self.cg.meta.layer_count + 1): + new_ids = self._new_ids_d[layer] + for id_ in new_ids: + val_dict = val_dicts.get(id_, {}) + children = self.cg.get_children(id_) + assert np.max( + self.cg.get_chunk_layers(children) + ) < self.cg.get_chunk_layer(id_), "Parent layer less than children." + val_dict[attributes.Hierarchy.Child] = children + rows.append( + self.cg.client.mutate_row( + serialize_uint64(id_), + val_dict, + time_stamp=self._time_stamp, + ) + ) + for child_id in children: + rows.append( + self.cg.client.mutate_row( + serialize_uint64(child_id), + {attributes.Hierarchy.Parent: id_}, + time_stamp=self._time_stamp, + ) + ) + return rows + self._update_root_id_lineage() diff --git a/pychunkedgraph/backend/chunkedgraph_exceptions.py b/pychunkedgraph/graph/exceptions.py similarity index 99% rename from pychunkedgraph/backend/chunkedgraph_exceptions.py rename to pychunkedgraph/graph/exceptions.py index dd3213f41..45aa57fc7 100644 --- a/pychunkedgraph/backend/chunkedgraph_exceptions.py +++ b/pychunkedgraph/graph/exceptions.py @@ -15,6 +15,7 @@ class PreconditionError(ChunkedGraphError): """Raised when preconditions for Chunked Graph operations are not met""" pass + class PostconditionError(ChunkedGraphError): """Raised when postconditions for Chunked Graph operations are not met""" pass diff --git a/pychunkedgraph/graph/lineage.py b/pychunkedgraph/graph/lineage.py new file mode 100644 index 000000000..6876ec563 --- /dev/null +++ b/pychunkedgraph/graph/lineage.py @@ -0,0 +1,227 @@ +""" +Functions for tracking root ID changes over time. 
+""" +from typing import Union +from typing import Optional +from typing import Iterable +from datetime import datetime +from collections import defaultdict + +import numpy as np +from networkx import DiGraph + +from . import attributes +from .exceptions import ChunkedGraphError +from .attributes import Hierarchy +from .attributes import OperationLogs +from .utils.basetypes import NODE_ID +from .utils.generic import get_min_time +from .utils.generic import get_max_time +from .utils.generic import get_valid_timestamp + + +def get_latest_root_id(cg, root_id: NODE_ID.type) -> np.ndarray: + """Returns the latest root id associated with the provided root id""" + id_working_set = [root_id] + latest_root_ids = [] + while len(id_working_set) > 0: + next_id = id_working_set[0] + del id_working_set[0] + node = cg.client.read_node(next_id, properties=attributes.Hierarchy.NewParent) + # Check if a new root id was attached to this root id + if node: + id_working_set.extend(node[0].value) + else: + latest_root_ids.append(next_id) + return np.unique(latest_root_ids) + + +def get_future_root_ids( + cg, + root_id: NODE_ID, + time_stamp: Optional[datetime] = get_max_time(), +) -> np.ndarray: + """ + Returns all future root ids emerging from this root + This search happens in a monotic fashion. At no point are past root + ids of future root ids taken into account. + """ + id_history = [] + next_ids = [root_id] + while len(next_ids): + temp_next_ids = [] + for next_id in next_ids: + node = cg.client.read_node( + next_id, + properties=[attributes.Hierarchy.NewParent, attributes.Hierarchy.Child], + ) + if attributes.Hierarchy.NewParent in node: + ids = node[attributes.Hierarchy.NewParent][0].value + row_time_stamp = node[attributes.Hierarchy.NewParent][0].timestamp + elif attributes.Hierarchy.Child in node: + ids = None + row_time_stamp = node[attributes.Hierarchy.Child][0].timestamp + else: + raise ChunkedGraphError(f"Error retrieving future root ID of {next_id}") + if row_time_stamp < get_valid_timestamp(time_stamp): + if ids is not None: + temp_next_ids.extend(ids) + if next_id != root_id: + id_history.append(next_id) + next_ids = temp_next_ids + return np.unique(np.array(id_history, dtype=NODE_ID)) + + +def get_past_root_ids( + cg, + root_id: NODE_ID, + time_stamp: Optional[datetime] = get_min_time(), +) -> np.ndarray: + """ + Returns all past root ids emerging from this root. + This search happens in a monotic fashion. At no point are future root + ids of past root ids taken into account. 
+ """ + id_history = [] + next_ids = [root_id] + while len(next_ids): + temp_next_ids = [] + for next_id in next_ids: + node = cg.client.read_node( + next_id, + properties=[ + attributes.Hierarchy.FormerParent, + attributes.Hierarchy.Child, + ], + ) + if attributes.Hierarchy.FormerParent in node: + ids = node[attributes.Hierarchy.FormerParent][0].value + row_time_stamp = node[attributes.Hierarchy.FormerParent][0].timestamp + elif attributes.Hierarchy.Child in node: + ids = None + row_time_stamp = node[attributes.Hierarchy.Child][0].timestamp + else: + raise ChunkedGraphError(f"Error retrieving past root ID of {next_id}.") + if row_time_stamp > get_valid_timestamp(time_stamp): + if ids is not None: + temp_next_ids.extend(ids) + if next_id != root_id: + id_history.append(next_id) + next_ids = temp_next_ids + return np.unique(np.array(id_history, dtype=NODE_ID)) + + +def get_previous_root_ids( + cg, + root_ids: Iterable[NODE_ID.type], +) -> dict: + """Returns immediate former root IDs (1 step history)""" + nodes_d = cg.client.read_nodes( + node_ids=root_ids, + properties=attributes.Hierarchy.FormerParent, + ) + result = {} + for root, val in nodes_d.items(): + result[root] = val[0].value + return result + + +def get_root_id_history( + cg, + root_id: NODE_ID, + time_stamp_past: Optional[datetime] = get_min_time(), + time_stamp_future: Optional[datetime] = get_max_time(), +) -> np.ndarray: + """ + Returns all future root ids emerging from this root + This search happens in a monotic fashion. At no point are future root + ids of past root ids or past root ids of future root ids taken into + account. + """ + past_ids = get_past_root_ids(cg, root_id, time_stamp=time_stamp_past) + future_ids = get_future_root_ids(cg, root_id, time_stamp=time_stamp_future) + return np.concatenate([past_ids, np.array([root_id], dtype=NODE_ID), future_ids]) + + +def _get_node_properties(node_entry: dict) -> dict: + node_d = {} + node_d["timestamp"] = node_entry[Hierarchy.Child][0].timestamp.timestamp() + if OperationLogs.OperationID in node_entry: + if len(node_entry[OperationLogs.OperationID]) == 2 or ( + len(node_entry[OperationLogs.OperationID]) == 1 + and Hierarchy.NewParent in node_entry + ): + node_d["operation_id"] = node_entry[OperationLogs.OperationID][0].value + return node_d + + +def lineage_graph( + cg, + node_ids: Union[int, Iterable[int]], + timestamp_past: Optional[datetime] = None, + timestamp_future: Optional[datetime] = None, +) -> DiGraph: + """ + Build lineage graph of a given root ID + going backwards in time until `timestamp_past` + and in future until `timestamp_future` + """ + if not isinstance(node_ids, np.ndarray) and not isinstance(node_ids, list): + node_ids = [node_ids] + + graph = DiGraph() + past_ids = np.array(node_ids, dtype=NODE_ID) + future_ids = np.array(node_ids, dtype=NODE_ID) + timestamp_past = float(0) if timestamp_past is None else timestamp_past.timestamp() + timestamp_future = ( + datetime.utcnow().timestamp() + if timestamp_future is None + else timestamp_future.timestamp() + ) + + while past_ids.size or future_ids.size: + nodes_raw = cg.client.read_nodes( + node_ids=np.unique(np.concatenate([past_ids, future_ids])) + ) + next_past_ids = [] + for k in past_ids: + val = nodes_raw[k] + node_d = _get_node_properties(val) + graph.add_node(k, **node_d) + if ( + node_d["timestamp"] < timestamp_past + or not Hierarchy.FormerParent in val + ): + continue + former_ids = val[Hierarchy.FormerParent][0].value + next_past_ids.extend( + [former_id for former_id in former_ids if not 
former_id in graph.nodes] + ) + for former in former_ids: + graph.add_edge(former, k) + + next_future_ids = [] + future_operation_id_dict = defaultdict(list) + for k in future_ids: + val = nodes_raw[k] + node_d = _get_node_properties(val) + graph.add_node(k, **node_d) + if node_d["timestamp"] > timestamp_future or not Hierarchy.NewParent in val: + continue + try: + future_operation_id_dict[node_d["operation_id"]].append(k) + except KeyError: + pass + + logs_raw = cg.client.read_log_entries(list(future_operation_id_dict.keys())) + for operation_id in future_operation_id_dict: + new_ids = logs_raw[operation_id][OperationLogs.RootID] + next_future_ids.extend( + [new_id for new_id in new_ids if not new_id in graph.nodes] + ) + for new_id in new_ids: + for k in future_operation_id_dict[operation_id]: + graph.add_edge(k, new_id) + past_ids = np.array(np.unique(next_past_ids), dtype=NODE_ID) + future_ids = np.array(np.unique(next_future_ids), dtype=NODE_ID) + return graph diff --git a/pychunkedgraph/graph/locks.py b/pychunkedgraph/graph/locks.py new file mode 100644 index 000000000..b3a3a0eb7 --- /dev/null +++ b/pychunkedgraph/graph/locks.py @@ -0,0 +1,134 @@ +from typing import Union +from typing import Sequence +from collections import defaultdict + +import numpy as np + +from . import exceptions +from .types import empty_1d +from .lineage import get_future_root_ids + + +class RootLock: + """Attempts to lock the requested root IDs using a unique operation ID. + :raises exceptions.LockingError: throws when one or more root ID locks could not be + acquired. + """ + + __slots__ = [ + "cg", + "root_ids", + "locked_root_ids", + "lock_acquired", + "operation_id", + "privileged_mode", + ] + # FIXME: `locked_root_ids` is only required and exposed because `cg.client.lock_roots` + # currently might lock different (more recent) root IDs than requested. + + def __init__( + self, + cg, + root_ids: Union[np.uint64, Sequence[np.uint64]], + *, + operation_id: np.uint64 = None, + privileged_mode: bool = False, + ) -> None: + self.cg = cg + self.root_ids = np.atleast_1d(root_ids) + self.locked_root_ids = [] + self.lock_acquired = False + self.operation_id = operation_id + # `privileged_mode` if True, override locking. + # This is intended to be used in extremely rare cases to fix errors + # caused by failed writes. Must be used with `operation_id`, + # meaning only existing failed operations can be run this way. + self.privileged_mode = privileged_mode + + def __enter__(self): + if self.privileged_mode: + assert self.operation_id is not None, "Please provide operation ID." + from warnings import warn + + warn("Warning: Privileged mode without acquiring lock.") + return self + if not self.operation_id: + self.operation_id = self.cg.id_client.create_operation_id() + + future_root_ids_d = defaultdict(lambda: empty_1d) + for id_ in self.root_ids: + future_root_ids_d[id_] = get_future_root_ids(self.cg, id_) + + self.lock_acquired, self.locked_root_ids = self.cg.client.lock_roots( + root_ids=self.root_ids, + operation_id=self.operation_id, + future_root_ids_d=future_root_ids_d, + max_tries=7, + ) + if not self.lock_acquired: + raise exceptions.LockingError("Could not acquire root lock") + return self + + def __exit__(self, exception_type, exception_value, traceback): + if self.lock_acquired: + for locked_root_id in self.locked_root_ids: + self.cg.client.unlock_root(locked_root_id, self.operation_id) + + +class IndefiniteRootLock: + """ + Attempts to lock the requested root IDs using a unique operation ID. 
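+    Used while persisting changes so that, if a write fails midway, the affected
+    roots remain locked until the failed operation is repaired.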
+    Assumes the root IDs have already been locked temporarily (with an expiring lock).
+    Also renews that temporary lock before locking indefinitely;
+    fails to lock indefinitely if the temporary lock cannot be re-acquired.
+
+    :raises exceptions.LockingError:
+        when a root ID lock cannot be renewed
+        or when it has already been locked indefinitely.
+    """
+
+    __slots__ = ["cg", "root_ids", "acquired", "operation_id", "privileged_mode"]
+
+    def __init__(
+        self,
+        cg,
+        operation_id: np.uint64,
+        root_ids: Union[np.uint64, Sequence[np.uint64]],
+        privileged_mode: bool = False,
+    ) -> None:
+        self.cg = cg
+        self.operation_id = operation_id
+        self.root_ids = np.atleast_1d(root_ids)
+        self.acquired = False
+        # `privileged_mode` if True, override locking.
+        # This is intended to be used in extremely rare cases to fix errors
+        # caused by failed writes.
+        self.privileged_mode = privileged_mode
+
+    def __enter__(self):
+        if self.privileged_mode:
+            from warnings import warn
+
+            warn("Warning: Privileged mode without acquiring indefinite lock.")
+            return self
+        if not self.cg.client.renew_locks(self.root_ids, self.operation_id):
+            raise exceptions.LockingError("Could not renew locks before writing.")
+
+        future_root_ids_d = defaultdict(lambda: empty_1d)
+        for id_ in self.root_ids:
+            future_root_ids_d[id_] = get_future_root_ids(self.cg, id_)
+        self.acquired, self.root_ids, failed = self.cg.client.lock_roots_indefinitely(
+            root_ids=self.root_ids,
+            operation_id=self.operation_id,
+            future_root_ids_d=future_root_ids_d,
+        )
+        if not self.acquired:
+            raise exceptions.LockingError(f"{failed} has been locked indefinitely.")
+        return self
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        if self.acquired:
+            for locked_root_id in self.root_ids:
+                self.cg.client.unlock_indefinitely_locked_root(
+                    locked_root_id, self.operation_id
+                )
diff --git a/pychunkedgraph/graph/meta.py b/pychunkedgraph/graph/meta.py
new file mode 100644
index 000000000..83d670ffe
--- /dev/null
+++ b/pychunkedgraph/graph/meta.py
@@ -0,0 +1,290 @@
+import json
+from datetime import timedelta
+from typing import Dict
+from typing import List
+from typing import Tuple
+from typing import Sequence
+from collections import namedtuple
+
+import numpy as np
+from cloudvolume import CloudVolume
+
+from .utils.generic import compute_bitmasks
+from .chunks.utils import get_chunks_boundary
+from ..utils.redis import keys as r_keys
+from ..utils.redis import get_rq_queue
+from ..utils.redis import get_redis_connection
+
+
+_datasource_fields = ("EDGES", "COMPONENTS", "WATERSHED", "DATA_VERSION", "CV_MIP")
+_datasource_defaults = (None, None, None, None, 0)
+DataSource = namedtuple(
+    "DataSource",
+    _datasource_fields,
+    defaults=_datasource_defaults,
+)
+
+
+_graphconfig_fields = (
+    "ID",  # ID_PREFIX and ID are together used when creating the graph
+    "ID_PREFIX",
+    "CHUNK_SIZE",
+    "FANOUT",
+    "LAYER_ID_BITS",  # number of bits reserved for layer id
+    "SPATIAL_BITS",  # number of bits used for each spatial dimension in id creation on level 1
+    "OVERWRITE",  # overwrites existing graph
+    "ROOT_LOCK_EXPIRY",
+    "ROOT_COUNTERS",
+)
+_graphconfig_defaults = (
+    None,
+    "",
+    None,
+    2,
+    8,
+    10,
+    False,
+    timedelta(minutes=3, seconds=0),
+    8,
+)
+GraphConfig = namedtuple(
+    "GraphConfig", _graphconfig_fields, defaults=_graphconfig_defaults
+)
+
+
+class ChunkedGraphMeta:
+    def __init__(
+        self, graph_config: GraphConfig, data_source: DataSource, custom_data: Dict = {}
+    ):
+        """
+        `custom_data`: stores arbitrary key-value information, for flexibility.
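+
+        A minimal construction sketch (hypothetical values):
+            graph_config = GraphConfig(ID="my_graph", CHUNK_SIZE=(256, 256, 512))
+            data_source = DataSource(WATERSHED="gs://bucket/ws")
+            meta = ChunkedGraphMeta(graph_config, data_source)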
+ """ + self._graph_config = graph_config + self._data_source = data_source + self._custom_data = custom_data + + self._ws_cv = None + self._layer_bounds_d = None + self._layer_count = None + self._bitmasks = None + + @property + def graph_config(self): + return self._graph_config + + @property + def data_source(self): + return self._data_source + + @property + def custom_data(self): + return self._custom_data + + @property + def ws_cv(self): + if self._ws_cv: + return self._ws_cv + + cache_key = f"{self.graph_config.ID}:ws_cv_info_cached" + try: + # try reading a cached info file for distributed workers + # useful to avoid md5 errors on high gcs load + redis = get_redis_connection() + cached_info = json.loads(redis.get(cache_key)) + self._ws_cv = CloudVolume(self._data_source.WATERSHED, info=cached_info) + except Exception: + self._ws_cv = CloudVolume(self._data_source.WATERSHED) + try: + redis.set(cache_key, json.dumps(self._ws_cv.info)) + except Exception: + ... + return self._ws_cv + + @property + def resolution(self): + return self.ws_cv.resolution # pylint: disable=no-member + + @property + def layer_count(self) -> int: + from .utils.generic import log_n + + if self._layer_count: + return self._layer_count + bbox = np.array(self.ws_cv.bounds.to_list()) # pylint: disable=no-member + bbox = bbox.reshape(2, 3) + n_chunks = get_chunks_boundary( + self.voxel_counts, np.array(self._graph_config.CHUNK_SIZE, dtype=int) + ) + self._layer_count = ( + int(np.ceil(log_n(np.max(n_chunks), self._graph_config.FANOUT))) + 2 + ) + return self._layer_count + + @layer_count.setter + def layer_count(self, count): + self._layer_count = count + self._bitmasks = compute_bitmasks( + self._layer_count, + s_bits_atomic_layer=self._graph_config.SPATIAL_BITS, + ) + + @property + def cv(self): + """Alias for watershed CV""" + return self.ws_cv + + @property + def bitmasks(self): + if self._bitmasks: + return self._bitmasks + self._bitmasks = compute_bitmasks( + self.layer_count, + s_bits_atomic_layer=self._graph_config.SPATIAL_BITS, + ) + return self._bitmasks + + @property + def voxel_bounds(self): + bounds = np.array(self.ws_cv.bounds.to_list()) # pylint: disable=no-member + return bounds.reshape(2, -1).T + + @property + def voxel_counts(self) -> Sequence[int]: + """returns number of voxels in each dimension""" + cv_bounds = np.array(self.ws_cv.bounds.to_list()) # pylint: disable=no-member + cv_bounds = cv_bounds.reshape(2, -1).T + voxel_counts = cv_bounds.copy() + voxel_counts -= cv_bounds[:, 0:1] # pylint: disable=unsubscriptable-object + voxel_counts = voxel_counts[:, 1] + return voxel_counts + + @property + def layer_chunk_bounds(self) -> Dict: + """number of chunks in each dimension in each layer {layer: [x,y,z]}""" + if self._layer_bounds_d: + return self._layer_bounds_d + + chunks_boundary = get_chunks_boundary( + self.voxel_counts, np.array(self._graph_config.CHUNK_SIZE, dtype=int) + ) + layer_bounds_d = {} + for layer in range(2, self.layer_count): + layer_bounds = chunks_boundary / (2 ** (layer - 2)) + layer_bounds_d[layer] = np.ceil(layer_bounds).astype(int) + self._layer_bounds_d = layer_bounds_d + return self._layer_bounds_d + + @layer_chunk_bounds.setter + def layer_chunk_bounds(self, layer_chunk_bounds_d): + self._layer_bounds_d = layer_chunk_bounds_d + + @property + def layer_chunk_counts(self) -> List: + """number of chunks in each layer""" + counts = [] + for layer in range(2, self.layer_count): + counts.append(np.prod(self.layer_chunk_bounds[layer])) + return counts + [1] + + @property + def 
edge_dtype(self): + if self.data_source.DATA_VERSION == 4: + dtype = [ + ("sv1", np.uint64), + ("sv2", np.uint64), + ("aff_x", np.float32), + ("area_x", np.uint64), + ("aff_y", np.float32), + ("area_y", np.uint64), + ("aff_z", np.float32), + ("area_z", np.uint64), + ] + elif self.data_source.DATA_VERSION == 3: + dtype = [ + ("sv1", np.uint64), + ("sv2", np.uint64), + ("aff_x", np.float64), + ("area_x", np.uint64), + ("aff_y", np.float64), + ("area_y", np.uint64), + ("aff_z", np.float64), + ("area_z", np.uint64), + ] + elif self.data_source.DATA_VERSION == 2: + dtype = [ + ("sv1", np.uint64), + ("sv2", np.uint64), + ("aff", np.float32), + ("area", np.uint64), + ] + else: + raise Exception() + return dtype + + @property + def READ_ONLY(self): + return self.custom_data.get("READ_ONLY", False) + + @property + def split_bounding_offset(self): + return self.custom_data.get( + "split_bounding_offset", + (240, 240, 24), + ) + + @property + def dataset_info(self) -> Dict: + info = self.ws_cv.info # pylint: disable=no-member + + info.update( + { + "chunks_start_at_voxel_offset": True, + "data_dir": self.data_source.WATERSHED, + "graph": { + "chunk_size": self.graph_config.CHUNK_SIZE, + "bounding_box": [2048, 2048, 512], + "n_bits_for_layer_id": self.graph_config.LAYER_ID_BITS, + "cv_mip": self.data_source.CV_MIP, + "n_layers": self.layer_count, + "spatial_bit_masks": self.bitmasks, + }, + } + ) + mesh_dir = self.custom_data.get("mesh", {}).get("dir", None) + if mesh_dir is not None: + info.update({"mesh": mesh_dir}) + return info + + def __getnewargs__(self): + return (self.graph_config, self.data_source) + + def __getstate__(self): + return { + "graph_config": self.graph_config, + "data_source": self.data_source, + "custom_data": self.custom_data, + } + + def __setstate__(self, state): + self.__init__( + state["graph_config"], state["data_source"], state.get("custom_data", {}) + ) + + def __str__(self): + from json import dumps + + meta_str = f"GRAPH_CONFIG\n{self.graph_config}\n" + meta_str += f"\nDATA_SOURCE\n{self.data_source}\n" + meta_str += f"\nCUSTOM_DATA\n{self.custom_data}\n" + meta_str += f"\nBITMASKS\n{self.bitmasks}\n" + meta_str += f"\nVOXEL_BOUNDS\n{self.voxel_bounds}\n" + meta_str += f"\nVOXEL_COUNTS\n{self.voxel_counts}\n" + meta_str += f"\nLAYER_CHUNK_BOUNDS\n{self.layer_chunk_bounds}\n" + meta_str += f"\nLAYER_CHUNK_COUNTS\n{self.layer_chunk_counts}\n" + meta_str += f"\nDATASET_INFO\n{dumps(self.dataset_info, indent=4)}\n" + return meta_str + + def is_out_of_bounds(self, chunk_coordinate): + return np.any(chunk_coordinate < 0) or np.any( + chunk_coordinate > 2 ** self.bitmasks[1] + ) diff --git a/pychunkedgraph/graph/misc.py b/pychunkedgraph/graph/misc.py new file mode 100644 index 000000000..b33e8a6fd --- /dev/null +++ b/pychunkedgraph/graph/misc.py @@ -0,0 +1,290 @@ +# pylint: disable=invalid-name, missing-docstring, c-extension-no-member, import-outside-toplevel + +import datetime +import collections +from typing import Dict +from typing import Optional +from typing import Sequence + +import fastremap +import numpy as np +from multiwrapper import multiprocessing_utils as mu + +from . import ChunkedGraph +from . 
import attributes +from .edges import Edges +from .utils import flatgraph +from .types import Agglomeration + + +def _read_delta_root_rows( + cg, + start_id, + end_id, + time_stamp_start, + time_stamp_end, +) -> Sequence[list]: + # apply column filters to avoid Lock columns + rows = cg.client.read_nodes( + start_id=start_id, + start_time=time_stamp_start, + end_id=end_id, + end_id_inclusive=False, + properties=[attributes.Hierarchy.FormerParent, attributes.Hierarchy.NewParent], + end_time=time_stamp_end, + end_time_inclusive=True, + ) + + # new roots are those that have no NewParent in this time window + new_root_ids = [ + k for (k, v) in rows.items() if attributes.Hierarchy.NewParent not in v + ] + + # expired roots are the IDs of FormerParent's + # whose timestamp is before the start_time + expired_root_ids = [] + for v in rows.values(): + if attributes.Hierarchy.FormerParent in v: + fp = v[attributes.Hierarchy.FormerParent] + for cell_entry in fp: + expired_root_ids.extend(cell_entry.value) + return new_root_ids, expired_root_ids + + +def _read_root_rows_thread(args) -> list: + start_seg_id, end_seg_id, serialized_cg_info, time_stamp = args + cg = ChunkedGraph(**serialized_cg_info) + start_id = cg.get_node_id(segment_id=start_seg_id, chunk_id=cg.root_chunk_id) + end_id = cg.get_node_id(segment_id=end_seg_id, chunk_id=cg.root_chunk_id) + rows = cg.client.read_nodes( + start_id=start_id, + end_id=end_id, + end_id_inclusive=False, + end_time=time_stamp, + end_time_inclusive=True, + ) + root_ids = [k for (k, v) in rows.items() if attributes.Hierarchy.NewParent not in v] + return root_ids + + +def get_proofread_root_ids( + cg: ChunkedGraph, + start_time: Optional[datetime.datetime] = None, + end_time: Optional[datetime.datetime] = None, +): + log_entries = cg.client.read_log_entries( + start_time=start_time, + end_time=end_time, + properties=[attributes.OperationLogs.RootID], + end_time_inclusive=True, + ) + root_chunks = [e[attributes.OperationLogs.RootID] for e in log_entries.values()] + if len(root_chunks) == 0: + return np.array([], dtype=np.uint64), np.array([], dtype=np.int64) + new_roots = np.concatenate(root_chunks) + + root_rows = cg.client.read_nodes( + node_ids=new_roots, properties=[attributes.Hierarchy.FormerParent] + ) + old_root_chunks = [np.empty(0, dtype=np.uint64)] + for e in root_rows.values(): + old_root_chunks.append(e[attributes.Hierarchy.FormerParent][0].value) + old_roots = np.concatenate(old_root_chunks) + return old_roots, new_roots + + +def get_latest_roots( + cg, time_stamp: Optional[datetime.datetime] = None, n_threads: int = 1 +) -> Sequence[np.uint64]: + # Create filters: time and id range + max_seg_id = cg.get_max_seg_id(cg.root_chunk_id) + 1 + n_blocks = 1 if n_threads == 1 else int(np.min([n_threads * 3 + 1, max_seg_id])) + seg_id_blocks = np.linspace(1, max_seg_id, n_blocks + 1, dtype=np.uint64) + cg_serialized_info = cg.get_serialized_info() + if n_threads > 1: + del cg_serialized_info["credentials"] + + multi_args = [] + for i_id_block in range(0, len(seg_id_blocks) - 1): + multi_args.append( + [ + seg_id_blocks[i_id_block], + seg_id_blocks[i_id_block + 1], + cg_serialized_info, + time_stamp, + ] + ) + + if n_threads == 1: + results = mu.multiprocess_func( + _read_root_rows_thread, + multi_args, + n_threads=n_threads, + verbose=False, + debug=n_threads == 1, + ) + else: + results = mu.multisubprocess_func( + _read_root_rows_thread, multi_args, n_threads=n_threads + ) + root_ids = [] + for result in results: + root_ids.extend(result) + return 
np.array(root_ids, dtype=np.uint64) + + +def get_delta_roots( + cg: ChunkedGraph, + time_stamp_start: datetime.datetime, + time_stamp_end: Optional[datetime.datetime] = None, +) -> Sequence[np.uint64]: + # Create filters: time and id range + start_id = np.uint64(cg.get_chunk_id(layer=cg.meta.layer_count) + 1) + end_id = cg.id_client.get_max_node_id( + cg.get_chunk_id(layer=cg.meta.layer_count), root_chunk=True + ) + np.uint64(1) + new_root_ids, expired_root_id_candidates = _read_delta_root_rows( + cg, start_id, end_id, time_stamp_start, time_stamp_end + ) + + # aggregate all the results together + new_root_ids = np.array(new_root_ids, dtype=np.uint64) + expired_root_id_candidates = np.array(expired_root_id_candidates, dtype=np.uint64) + # filter for uniqueness + expired_root_id_candidates = np.unique(expired_root_id_candidates) + + # filter out the expired root id's whose creation (measured by the timestamp + # of their Child links) is after the time_stamp_start + rows = cg.client.read_nodes( + node_ids=expired_root_id_candidates, + properties=[attributes.Hierarchy.Child], + end_time=time_stamp_start, + ) + expired_root_ids = np.array([k for (k, v) in rows.items()], dtype=np.uint64) + return np.array(new_root_ids, dtype=np.uint64), expired_root_ids + + +def get_contact_sites( + cg: ChunkedGraph, + root_id, + bounding_box=None, + bbox_is_coordinate=True, + compute_partner=True, + time_stamp=None, +): + # Get information about the root id + # All supervoxels + sv_ids = cg.get_subgraph( + root_id, + bbox=bounding_box, + bbox_is_coordinate=bbox_is_coordinate, + nodes_only=True, + return_flattened=True, + ) + # All edges that are _not_ connected / on + edges, _, areas = cg.get_subgraph_edges( + root_id, + bbox=bounding_box, + bbox_is_coordinate=bbox_is_coordinate, + connected_edges=False, + ) + + # Build area lookup dictionary + cs_svs = edges[~np.in1d(edges, sv_ids).reshape(-1, 2)] + area_dict = collections.defaultdict(int) + + for area, sv_id in zip(areas, cs_svs): + area_dict[sv_id] += area + + area_dict_vec = np.vectorize(area_dict.get) + # Extract svs from contacting root ids + u_cs_svs = np.unique(cs_svs) + # Load edges of these cs_svs + edges_cs_svs_rows = cg.client.read_nodes( + node_ids=u_cs_svs, + # columns=[attributes.Connectivity.Partner, attributes.Connectivity.Connected], + ) + pre_cs_edges = [] + for ri in edges_cs_svs_rows.items(): + r = cg._retrieve_connectivity(ri) + pre_cs_edges.extend(r[0]) + graph, _, _, unique_ids = flatgraph.build_gt_graph(pre_cs_edges, make_directed=True) + # connected components in this graph will be combined in one component + ccs = flatgraph.connected_components(graph) + cs_dict = collections.defaultdict(list) + for cc in ccs: + cc_sv_ids = unique_ids[cc] + cc_sv_ids = cc_sv_ids[np.in1d(cc_sv_ids, u_cs_svs)] + cs_areas = area_dict_vec(cc_sv_ids) + partner_root_id = ( + int(cg.get_root(cc_sv_ids[0], time_stamp=time_stamp)) + if compute_partner + else len(cs_dict) + ) + cs_dict[partner_root_id].append(np.sum(cs_areas)) + return cs_dict + + +def get_agglomerations( + l2id_children_d: Dict, + in_edges: Edges, + ot_edges: Edges, + cx_edges: Edges, + sv_parent_d: Dict, +) -> Dict[np.uint64, Agglomeration]: + l2id_agglomeration_d = {} + _in = fastremap.remap(in_edges.node_ids1, sv_parent_d, preserve_missing_labels=True) + _ot = fastremap.remap(ot_edges.node_ids1, sv_parent_d, preserve_missing_labels=True) + _cx = fastremap.remap(cx_edges.node_ids1, sv_parent_d, preserve_missing_labels=True) + for l2id in l2id_children_d: + l2id_agglomeration_d[l2id] = 
Agglomeration( + l2id, + l2id_children_d[l2id], + in_edges[_in == l2id], + ot_edges[_ot == l2id], + cx_edges[_cx == l2id], + ) + return l2id_agglomeration_d + + +def get_activated_edges( + cg: ChunkedGraph, operation_id: int, delta: Optional[int] = 100 +) -> np.ndarray: + """ + Returns edges that were made active by a merge operation. + """ + from datetime import timedelta + from .edits import merge_preprocess + from .operation import GraphEditOperation + from .operation import MergeOperation + from .utils.generic import get_bounding_box as get_bbox + + log, time_stamp = cg.client.read_log_entry(operation_id) + assert ( + GraphEditOperation.get_log_record_type(log) == MergeOperation + ), "Must be a merge operation." + + time_stamp -= timedelta(milliseconds=delta) + operation = GraphEditOperation.from_log_record(cg, log) + bbox = get_bbox( + operation.source_coords, operation.sink_coords, operation.bbox_offset + ) + + root_ids = set( + cg.get_roots( + operation.added_edges.ravel(), assert_roots=True, time_stamp=time_stamp + ) + ) + assert len(root_ids) > 1, "More than one segment is required for merge." + edges = operation.cg.get_subgraph( + root_ids, + bbox=bbox, + bbox_is_coordinate=True, + edges_only=True, + ) + return merge_preprocess( + cg, + subgraph_edges=edges, + supervoxels=operation.added_edges.ravel(), + parent_ts=time_stamp, + ) diff --git a/pychunkedgraph/graph/operation.py b/pychunkedgraph/graph/operation.py new file mode 100644 index 000000000..d0d0e172a --- /dev/null +++ b/pychunkedgraph/graph/operation.py @@ -0,0 +1,1257 @@ +# pylint: disable=invalid-name, missing-docstring, too-many-lines, protected-access + +from abc import ABC, abstractmethod +from collections import namedtuple +from datetime import datetime +from typing import TYPE_CHECKING +from typing import Dict +from typing import List +from typing import Type +from typing import Tuple +from typing import Union +from typing import Optional +from typing import Sequence +from functools import reduce + +import numpy as np +from google.cloud import bigtable + +from . import locks +from . import edits +from . import types +from . import attributes +from .edges import Edges +from .edges.utils import get_edges_status +from .utils import basetypes +from .utils import serializers +from .cache import CacheService +from .cutting import run_multicut +from .exceptions import PreconditionError +from .exceptions import PostconditionError +from .utils.generic import get_bounding_box as get_bbox +from ..logging.log_db import TimeIt + + +if TYPE_CHECKING: + from .chunkedgraph import ChunkedGraph + + +class GraphEditOperation(ABC): + __slots__ = [ + "cg", + "user_id", + "source_coords", + "sink_coords", + "parent_ts", + "privileged_mode", + ] + Result = namedtuple("Result", ["operation_id", "new_root_ids", "new_lvl2_ids"]) + + def __init__( + self, + cg: "ChunkedGraph", + *, + user_id: str, + source_coords: Optional[Sequence[Sequence[int]]] = None, + sink_coords: Optional[Sequence[Sequence[int]]] = None, + ) -> None: + super().__init__() + self.cg = cg + self.user_id = user_id + self.source_coords = None + self.sink_coords = None + # `parent_ts` is the timestamp to get parents/roots + # after an operation fails while persisting changes. + # When that happens, parents/roots before the operation must be used to fix it. + # it is passed as an argument to `GraphEditOperation.execute()` + self.parent_ts = None + # `privileged_mode` if True, override locking. 
+ # This is intended to be used in extremely rare cases to fix errors + # caused by failed writes. + self.privileged_mode = False + + if source_coords is not None: + self.source_coords = np.atleast_2d(source_coords).astype( + basetypes.COORDINATES + ) + if self.source_coords.size == 0: + self.source_coords = None + if sink_coords is not None: + self.sink_coords = np.atleast_2d(sink_coords).astype(basetypes.COORDINATES) + if self.sink_coords.size == 0: + self.sink_coords = None + + @classmethod + def _resolve_undo_chain( + cls, + cg: "ChunkedGraph", + *, + user_id: str, + operation_id: np.uint64, + is_undo: bool, + multicut_as_split: bool, + ): + log_record, _ = cg.client.read_log_entry(operation_id) + log_record_type = cls.get_log_record_type(log_record) + + while log_record_type in (RedoOperation, UndoOperation): + if log_record_type is RedoOperation: + operation_id = log_record[attributes.OperationLogs.RedoOperationID] + else: + is_undo = not is_undo + operation_id = log_record[attributes.OperationLogs.UndoOperationID] + log_record, _ = cg.client.read_log_entry(operation_id) + log_record_type = cls.get_log_record_type(log_record) + + if is_undo: + return UndoOperation( + cg, + user_id=user_id, + superseded_operation_id=operation_id, + multicut_as_split=multicut_as_split, + ) + else: + return RedoOperation( + cg, + user_id=user_id, + superseded_operation_id=operation_id, + multicut_as_split=multicut_as_split, + ) + + @staticmethod + def get_log_record_type( + log_record: Dict[attributes._Attribute, Union[np.ndarray, np.number]], + *, + multicut_as_split=True, + ) -> Type["GraphEditOperation"]: + """Guesses the type of GraphEditOperation given a log record dictionary. + :param log_record: log record dictionary + :type log_record: Dict[attributes._Attribute, Union[np.ndarray, np.number]] + :param multicut_as_split: If true, treat MulticutOperation as SplitOperation + + :return: The type of the matching GraphEditOperation subclass + :rtype: Type["GraphEditOperation"] + """ + if attributes.OperationLogs.UndoOperationID in log_record: + return UndoOperation + if attributes.OperationLogs.RedoOperationID in log_record: + return RedoOperation + if attributes.OperationLogs.AddedEdge in log_record: + return MergeOperation + if attributes.OperationLogs.RemovedEdge in log_record: + if ( + multicut_as_split + or attributes.OperationLogs.BoundingBoxOffset not in log_record + ): + return SplitOperation + return MulticutOperation + if attributes.OperationLogs.BoundingBoxOffset in log_record: + return MulticutOperation + raise TypeError("Could not determine graph operation type.") + + @classmethod + def from_log_record( + cls, + cg: "ChunkedGraph", + log_record: Dict[attributes._Attribute, Union[np.ndarray, np.number]], + *, + multicut_as_split: bool = True, + ) -> "GraphEditOperation": + """Generates the correct GraphEditOperation given a log record dictionary. + :param cg: The "ChunkedGraph" instance + :type cg: "ChunkedGraph" + :param log_record: log record dictionary + :type log_record: Dict[attributes._Attribute, Union[np.ndarray, np.number]] + :param multicut_as_split: If true, don't recalculate MultiCutOperation, just + use the resulting removed edges and generate SplitOperation instead (faster). 
+ :type multicut_as_split: bool + + :return: The matching GraphEditOperation subclass + :rtype: "GraphEditOperation" + """ + + def _optional(column): + try: + return log_record[column] + except KeyError: + return None + + log_record_type = cls.get_log_record_type( + log_record, multicut_as_split=multicut_as_split + ) + user_id = log_record[attributes.OperationLogs.UserID] + + if log_record_type is UndoOperation: + superseded_operation_id = log_record[ + attributes.OperationLogs.UndoOperationID + ] + return cls.undo_operation( + cg, + user_id=user_id, + operation_id=superseded_operation_id, + multicut_as_split=multicut_as_split, + ) + + if log_record_type is RedoOperation: + superseded_operation_id = log_record[ + attributes.OperationLogs.RedoOperationID + ] + return cls.redo_operation( + cg, + user_id=user_id, + operation_id=superseded_operation_id, + multicut_as_split=multicut_as_split, + ) + + source_coords = _optional(attributes.OperationLogs.SourceCoordinate) + sink_coords = _optional(attributes.OperationLogs.SinkCoordinate) + + if log_record_type is MergeOperation: + added_edges = log_record[attributes.OperationLogs.AddedEdge] + affinities = _optional(attributes.OperationLogs.Affinity) + return MergeOperation( + cg, + user_id=user_id, + source_coords=source_coords, + sink_coords=sink_coords, + added_edges=added_edges, + affinities=affinities, + ) + + if log_record_type is SplitOperation: + removed_edges = log_record[attributes.OperationLogs.RemovedEdge] + return SplitOperation( + cg, + user_id=user_id, + source_coords=source_coords, + sink_coords=sink_coords, + removed_edges=removed_edges, + ) + + if log_record_type is MulticutOperation: + bbox_offset = log_record[attributes.OperationLogs.BoundingBoxOffset] + source_ids = log_record[attributes.OperationLogs.SourceID] + sink_ids = log_record[attributes.OperationLogs.SinkID] + removed_edges = log_record[attributes.OperationLogs.RemovedEdge] + return MulticutOperation( + cg, + user_id=user_id, + source_coords=source_coords, + sink_coords=sink_coords, + bbox_offset=bbox_offset, + source_ids=source_ids, + sink_ids=sink_ids, + removed_edges=removed_edges, + ) + + raise TypeError("Could not determine graph operation type.") + + @classmethod + def from_operation_id( + cls, + cg: "ChunkedGraph", + operation_id: np.uint64, + *, + multicut_as_split: bool = True, + privileged_mode: Optional[bool] = False, + ): + """Generates the correct GraphEditOperation given a operation ID. + :param cg: The "ChunkedGraph" instance + :type cg: "ChunkedGraph" + :param operation_id: The operation ID + :type operation_id: np.uint64 + :param multicut_as_split: If true, don't recalculate MultiCutOperation, just + use the resulting removed edges and generate SplitOperation instead (faster). + :type multicut_as_split: bool + + `privileged_mode` if True, override locking. + This is intended to be used in extremely rare cases to fix errors + caused by failed writes. + + :return: The matching GraphEditOperation subclass + :rtype: "GraphEditOperation" + """ + log, _ = cg.client.read_log_entry(operation_id) + operation = cls.from_log_record(cg, log, multicut_as_split=multicut_as_split) + operation.privileged_mode = privileged_mode + return operation + + @classmethod + def undo_operation( + cls, + cg: "ChunkedGraph", + *, + user_id: str, + operation_id: np.uint64, + multicut_as_split: bool = True, + ) -> Union["UndoOperation", "RedoOperation"]: + """Create a GraphEditOperation that, if executed, would undo the changes introduced by + operation_id. 
+ + NOTE: If operation_id is an UndoOperation, this function might return an instance of + RedoOperation instead (depending on how the Undo/Redo chain unrolls) + + :param cg: The "ChunkedGraph" instance + :type cg: "ChunkedGraph" + :param user_id: User that should be associated with this undo operation + :type user_id: str + :param operation_id: The operation ID to be undone + :type operation_id: np.uint64 + :param multicut_as_split: If true, don't recalculate MultiCutOperation, just + use the resulting removed edges and generate SplitOperation instead (faster). + :type multicut_as_split: bool + + :return: A GraphEditOperation that, if executed, will undo the change introduced by + operation_id. + :rtype: Union["UndoOperation", "RedoOperation"] + """ + return cls._resolve_undo_chain( + cg, + user_id=user_id, + operation_id=operation_id, + is_undo=True, + multicut_as_split=multicut_as_split, + ) + + @classmethod + def redo_operation( + cls, + cg: "ChunkedGraph", + *, + user_id: str, + operation_id: np.uint64, + multicut_as_split=True, + ) -> Union["UndoOperation", "RedoOperation"]: + """Create a GraphEditOperation that, if executed, would redo the changes introduced by + operation_id. + + NOTE: If operation_id is an UndoOperation, this function might return an instance of + UndoOperation instead (depending on how the Undo/Redo chain unrolls) + + :param cg: The "ChunkedGraph" instance + :type cg: "ChunkedGraph" + :param user_id: User that should be associated with this redo operation + :type user_id: str + :param operation_id: The operation ID to be redone + :type operation_id: np.uint64 + :param multicut_as_split: If true, don't recalculate MultiCutOperation, just + use the resulting removed edges and generate SplitOperation instead (faster). + :type multicut_as_split: bool + + :return: A GraphEditOperation that, if executed, will redo the changes introduced by + operation_id. + :rtype: Union["UndoOperation", "RedoOperation"] + """ + return cls._resolve_undo_chain( + cg, + user_id=user_id, + operation_id=operation_id, + is_undo=False, + multicut_as_split=multicut_as_split, + ) + + @abstractmethod + def _update_root_ids(self) -> np.ndarray: + """Retrieves and validates the most recent root IDs affected by this GraphEditOperation. + :return: New most recent root IDs + :rtype: np.ndarray + """ + + @abstractmethod + def _apply( + self, *, operation_id, timestamp + ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: + """Initiates the graph operation calculation. 
+        :return: New root IDs, new Lvl2 node IDs, and affected records
+        :rtype: Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]
+        """
+
+    @abstractmethod
+    def _create_log_record(
+        self,
+        *,
+        operation_id,
+        timestamp,
+        operation_ts,
+        new_root_ids,
+        status=1,
+        exception="",
+    ) -> "bigtable.row.Row":
+        """Creates a log record with all necessary information to replay the current
+        GraphEditOperation.
+        :return: Bigtable row containing the log record
+        :rtype: bigtable.row.Row
+        """
+
+    @abstractmethod
+    def invert(self) -> "GraphEditOperation":
+        """Creates a GraphEditOperation that would cancel out changes introduced by the current
+        GraphEditOperation.
+        :return: The inverse of GraphEditOperation
+        :rtype: GraphEditOperation
+        """
+
+    def execute(
+        self, *, operation_id=None, parent_ts=None, override_ts=None
+    ) -> "GraphEditOperation.Result":
+        """
+        Executes the current GraphEditOperation:
+        * Calls the subclass's _update_root_ids method
+        * Locks root IDs normally
+        * Calls the subclass's _apply method
+        * Calls the subclass's _create_log_record method
+        * Locks roots indefinitely to prevent corruption in case persisting changes fails
+          (such cases are retried in a cron job)
+        * Persists changes
+        * Releases indefinite locks
+        * Releases the normal root ID lock
+
+        `parent_ts` is the timestamp used to get parents/roots;
+        for normal edits it is None, which means the latest parents/roots.
+        But after an operation fails while persisting changes,
+        parents/roots from before the operation must be used to fix it.
+        `override_ts` can be used to preserve the proper timestamp in such cases.
+        """
+        is_merge = isinstance(self, MergeOperation)
+        op_type = "merge" if is_merge else "split"
+        self.parent_ts = parent_ts
+        root_ids = self._update_root_ids()
+        with locks.RootLock(
+            self.cg,
+            root_ids,
+            operation_id=operation_id,
+            privileged_mode=self.privileged_mode,
+        ) as lock:
+            self.cg.cache = CacheService(self.cg)
+            self.cg.meta.custom_data["operation_id"] = operation_id
+            timestamp = self.cg.client.get_consolidated_lock_timestamp(
+                lock.locked_root_ids,
+                np.array([lock.operation_id] * len(lock.locked_root_ids)),
+            )
+
+            log_record_before_edit = self._create_log_record(
+                operation_id=lock.operation_id,
+                new_root_ids=types.empty_1d,
+                timestamp=timestamp,
+                operation_ts=override_ts if override_ts else timestamp,
+                status=attributes.OperationLogs.StatusCodes.CREATED.value,
+            )
+            self.cg.client.write([log_record_before_edit])
+
+            try:
+                with TimeIt(f"{op_type}.apply", self.cg.graph_id, lock.operation_id):
+                    new_root_ids, new_lvl2_ids, affected_records = self._apply(
+                        operation_id=lock.operation_id,
+                        timestamp=override_ts if override_ts else timestamp,
+                    )
+                    if self.cg.meta.READ_ONLY:
+                        # return without persisting changes
+                        return GraphEditOperation.Result(
+                            operation_id=lock.operation_id,
+                            new_root_ids=new_root_ids,
+                            new_lvl2_ids=new_lvl2_ids,
+                        )
+            except PreconditionError as err:
+                self.cg.cache = None
+                raise PreconditionError(err) from err
+            except PostconditionError as err:
+                self.cg.cache = None
+                raise PostconditionError(err) from err
+            except Exception as err:
+                # unknown exception, update log record with error
+                self.cg.cache = None
+                log_record_error = self._create_log_record(
+                    operation_id=lock.operation_id,
+                    new_root_ids=types.empty_1d,
+                    timestamp=None,
+                    operation_ts=override_ts if override_ts else timestamp,
+                    status=attributes.OperationLogs.StatusCodes.EXCEPTION.value,
+                    exception=repr(err),
+                )
+                self.cg.client.write([log_record_error])
+                raise Exception(err)
+
+            with
TimeIt(f"{op_type}.write", self.cg.graph_id, lock.operation_id): + result = self._write( + lock, + override_ts if override_ts else timestamp, + new_root_ids, + new_lvl2_ids, + affected_records, + ) + return result + + def _write(self, lock, timestamp, new_root_ids, new_lvl2_ids, affected_records): + """Helper to persist changes after an edit.""" + new_root_ids = np.array(new_root_ids, dtype=basetypes.NODE_ID) + new_lvl2_ids = np.array(new_lvl2_ids, dtype=basetypes.NODE_ID) + + # this must be written first to indicate writing has started. + log_record_after_edit = self._create_log_record( + operation_id=lock.operation_id, + new_root_ids=new_root_ids, + timestamp=None, + operation_ts=timestamp, + status=attributes.OperationLogs.StatusCodes.WRITE_STARTED.value, + ) + + with locks.IndefiniteRootLock( + self.cg, + lock.operation_id, + lock.locked_root_ids, + privileged_mode=lock.privileged_mode, + ): + # indefinite lock for writing, if a node instance or pod dies during this + # the roots must stay locked indefinitely to prevent further corruption. + self.cg.client.write( + [log_record_after_edit] + affected_records, + lock.locked_root_ids, + operation_id=lock.operation_id, + slow_retry=False, + ) + log_record_success = self._create_log_record( + operation_id=lock.operation_id, + new_root_ids=new_root_ids, + timestamp=None, + operation_ts=timestamp, + status=attributes.OperationLogs.StatusCodes.SUCCESS.value, + ) + self.cg.client.write([log_record_success]) + self.cg.cache = None + return GraphEditOperation.Result( + operation_id=lock.operation_id, + new_root_ids=new_root_ids, + new_lvl2_ids=new_lvl2_ids, + ) + + +class MergeOperation(GraphEditOperation): + """Merge Operation: Connect *known* pairs of supervoxels by adding a (weighted) edge. + + :param cg: The "ChunkedGraph" object + :type cg: "ChunkedGraph" + :param user_id: User ID that will be assigned to this operation + :type user_id: str + :param added_edges: Supervoxel IDs of all added edges [[source, sink]] + :type added_edges: Sequence[Sequence[np.uint64]] + :param source_coords: world space coordinates in nm, + corresponding to IDs in added_edges[:,0], defaults to None + :type source_coords: Optional[Sequence[Sequence[int]]], optional + :param sink_coords: world space coordinates in nm, + corresponding to IDs in added_edges[:,1], defaults to None + :type sink_coords: Optional[Sequence[Sequence[int]]], optional + :param affinities: edge weights for newly added edges, + entries corresponding to added_edges, defaults to None + :type affinities: Optional[Sequence[np.float32]], optional + """ + + __slots__ = [ + "source_ids", + "sink_ids", + "added_edges", + "affinities", + "bbox_offset", + "allow_same_segment_merge", + ] + + def __init__( + self, + cg: "ChunkedGraph", + *, + user_id: str, + added_edges: Sequence[Sequence[np.uint64]], + source_coords: Sequence[Sequence[int]], + sink_coords: Sequence[Sequence[int]], + bbox_offset: Tuple[int, int, int] = (240, 240, 24), + affinities: Optional[Sequence[np.float32]] = None, + allow_same_segment_merge: Optional[bool] = False, + ) -> None: + super().__init__( + cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords + ) + self.added_edges = np.atleast_2d(added_edges).astype(basetypes.NODE_ID) + self.bbox_offset = np.atleast_1d(bbox_offset).astype(basetypes.COORDINATES) + self.allow_same_segment_merge = allow_same_segment_merge + + self.affinities = None + if affinities is not None: + self.affinities = np.atleast_1d(affinities).astype(basetypes.EDGE_AFFINITY) + if 
self.affinities.size == 0: + self.affinities = None + + if np.any(np.equal(self.added_edges[:, 0], self.added_edges[:, 1])): + raise PreconditionError("Requested merge contains at least 1 self-loop.") + + layers = self.cg.get_chunk_layers(self.added_edges.ravel()) + assert np.sum(layers) == layers.size, "Supervoxels expected." + + def _update_root_ids(self) -> np.ndarray: + root_ids = np.unique( + self.cg.get_roots( + self.added_edges.ravel(), assert_roots=True, time_stamp=self.parent_ts + ) + ) + return root_ids + + def _apply( + self, *, operation_id, timestamp + ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: + root_ids = set( + self.cg.get_roots( + self.added_edges.ravel(), assert_roots=True, time_stamp=self.parent_ts + ) + ) + if len(root_ids) < 2 and not self.allow_same_segment_merge: + raise PreconditionError("Supervoxels must belong to different objects.") + bbox = get_bbox(self.source_coords, self.sink_coords, self.bbox_offset) + with TimeIt("subgraph", self.cg.graph_id, operation_id): + edges = self.cg.get_subgraph( + root_ids, + bbox=bbox, + bbox_is_coordinate=True, + edges_only=True, + ) + + with TimeIt("preprocess", self.cg.graph_id, operation_id): + inactive_edges = edits.merge_preprocess( + self.cg, + subgraph_edges=edges, + supervoxels=self.added_edges.ravel(), + parent_ts=self.parent_ts, + ) + + atomic_edges, fake_edge_rows = edits.check_fake_edges( + self.cg, + atomic_edges=self.added_edges, + inactive_edges=inactive_edges, + time_stamp=timestamp, + parent_ts=self.parent_ts, + ) + with TimeIt("add_edges", self.cg.graph_id, operation_id): + new_roots, new_l2_ids, new_entries = edits.add_edges( + self.cg, + atomic_edges=atomic_edges, + operation_id=operation_id, + time_stamp=timestamp, + parent_ts=self.parent_ts, + ) + return new_roots, new_l2_ids, fake_edge_rows + new_entries + + def _create_log_record( + self, + *, + operation_id: np.uint64, + timestamp: datetime, + operation_ts: datetime, + new_root_ids: Sequence[np.uint64], + status: int = 1, + exception: str = "", + ) -> "bigtable.row.Row": + val_dict = { + attributes.OperationLogs.UserID: self.user_id, + attributes.OperationLogs.RootID: new_root_ids, + attributes.OperationLogs.AddedEdge: self.added_edges, + attributes.OperationLogs.Status: status, + attributes.OperationLogs.OperationException: exception, + attributes.OperationLogs.OperationTimeStamp: operation_ts, + } + if self.source_coords is not None: + val_dict[attributes.OperationLogs.SourceCoordinate] = self.source_coords + if self.sink_coords is not None: + val_dict[attributes.OperationLogs.SinkCoordinate] = self.sink_coords + if self.affinities is not None: + val_dict[attributes.OperationLogs.Affinity] = self.affinities + return self.cg.client.mutate_row( + serializers.serialize_uint64(operation_id), val_dict, timestamp + ) + + def invert(self) -> "SplitOperation": + return SplitOperation( + self.cg, + user_id=self.user_id, + removed_edges=self.added_edges, + source_coords=self.source_coords, + sink_coords=self.sink_coords, + ) + + +class SplitOperation(GraphEditOperation): + """Split Operation: Cut *known* pairs of supervoxel that are directly connected by an edge. 
+ + :param cg: The "ChunkedGraph" object + :type cg: "ChunkedGraph" + :param user_id: User ID that will be assigned to this operation + :type user_id: str + :param removed_edges: Supervoxel IDs of all removed edges [[source, sink]] + :type removed_edges: Sequence[Sequence[np.uint64]] + :param source_coords: world space coordinates in nm, corresponding to IDs in + removed_edges[:,0], defaults to None + :type source_coords: Optional[Sequence[Sequence[int]]], optional + :param sink_coords: world space coordinates in nm, corresponding to IDs in + removed_edges[:,1], defaults to None + :type sink_coords: Optional[Sequence[Sequence[int]]], optional + """ + + __slots__ = ["removed_edges", "bbox_offset"] + + def __init__( + self, + cg: "ChunkedGraph", + *, + user_id: str, + removed_edges: Sequence[Sequence[np.uint64]], + source_coords: Optional[Sequence[Sequence[int]]] = None, + sink_coords: Optional[Sequence[Sequence[int]]] = None, + bbox_offset: Tuple[int] = (240, 240, 24), + ) -> None: + super().__init__( + cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords + ) + self.removed_edges = np.atleast_2d(removed_edges).astype(basetypes.NODE_ID) + self.bbox_offset = np.atleast_1d(bbox_offset).astype(basetypes.COORDINATES) + if np.any(np.equal(self.removed_edges[:, 0], self.removed_edges[:, 1])): + raise PreconditionError("Requested split contains at least 1 self-loop.") + + layers = self.cg.get_chunk_layers(self.removed_edges.ravel()) + assert np.sum(layers) == layers.size, "IDs must be supervoxels." + + def _update_root_ids(self) -> np.ndarray: + root_ids = np.unique( + self.cg.get_roots( + self.removed_edges.ravel(), + assert_roots=True, + time_stamp=self.parent_ts, + ) + ) + if len(root_ids) > 1: + raise PreconditionError("Supervoxels must belong to the same object.") + return root_ids + + def _apply( + self, *, operation_id, timestamp + ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: + if ( + len( + set( + self.cg.get_roots( + self.removed_edges.ravel(), + assert_roots=True, + time_stamp=self.parent_ts, + ) + ) + ) + > 1 + ): + raise PreconditionError("Supervoxels must belong to the same object.") + + with TimeIt("subgraph", self.cg.graph_id, operation_id): + l2id_agglomeration_d, _ = self.cg.get_l2_agglomerations( + self.cg.get_parents( + self.removed_edges.ravel(), time_stamp=self.parent_ts + ), + ) + with TimeIt("remove_edges", self.cg.graph_id, operation_id): + return edits.remove_edges( + self.cg, + operation_id=operation_id, + atomic_edges=self.removed_edges, + l2id_agglomeration_d=l2id_agglomeration_d, + time_stamp=timestamp, + parent_ts=self.parent_ts, + ) + + def _create_log_record( + self, + *, + operation_id: np.uint64, + timestamp: datetime, + operation_ts: datetime, + new_root_ids: Sequence[np.uint64], + status: int = 1, + exception: str = "", + ) -> "bigtable.row.Row": + val_dict = { + attributes.OperationLogs.UserID: self.user_id, + attributes.OperationLogs.RootID: new_root_ids, + attributes.OperationLogs.RemovedEdge: self.removed_edges, + attributes.OperationLogs.Status: status, + attributes.OperationLogs.OperationException: exception, + attributes.OperationLogs.OperationTimeStamp: operation_ts, + } + if self.source_coords is not None: + val_dict[attributes.OperationLogs.SourceCoordinate] = self.source_coords + if self.sink_coords is not None: + val_dict[attributes.OperationLogs.SinkCoordinate] = self.sink_coords + + return self.cg.client.mutate_row( + serializers.serialize_uint64(operation_id), val_dict, timestamp + ) + + def invert(self) -> 
"MergeOperation": + return MergeOperation( + self.cg, + user_id=self.user_id, + added_edges=self.removed_edges, + source_coords=self.source_coords, + sink_coords=self.sink_coords, + ) + + +class MulticutOperation(GraphEditOperation): + """ + Multicut Operation: Apply min-cut algorithm to identify suitable edges for removal + in order to separate two groups of supervoxels. + + :param cg: The "ChunkedGraph" object + :type cg: "ChunkedGraph" + :param user_id: User ID that will be assigned to this operation + :type user_id: str + :param source_ids: Supervoxel IDs that should be separated from supervoxel IDs in sink_ids + :type souce_ids: Sequence[np.uint64] + :param sink_ids: Supervoxel IDs that should be separated from supervoxel IDs in source_ids + :type sink_ids: Sequence[np.uint64] + :param source_coords: world space coordinates in nm, corresponding to IDs in source_ids + :type source_coords: Sequence[Sequence[int]] + :param sink_coords: world space coordinates in nm, corresponding to IDs in sink_ids + :type sink_coords: Sequence[Sequence[int]] + :param bbox_offset: Padding for min-cut bounding box, applied to min/max coordinates + retrieved from source_coords and sink_coords, defaults to None + :type bbox_offset: Sequence[int] + """ + + __slots__ = [ + "source_ids", + "sink_ids", + "removed_edges", + "bbox_offset", + "path_augment", + "disallow_isolating_cut", + ] + + def __init__( + self, + cg: "ChunkedGraph", + *, + user_id: str, + source_ids: Sequence[np.uint64], + sink_ids: Sequence[np.uint64], + source_coords: Sequence[Sequence[int]], + sink_coords: Sequence[Sequence[int]], + bbox_offset: Sequence[int], + removed_edges: Sequence[Sequence[np.uint64]] = types.empty_2d, + path_augment: bool = True, + disallow_isolating_cut: bool = True, + ) -> None: + super().__init__( + cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords + ) + self.removed_edges = removed_edges + self.source_ids = np.atleast_1d(source_ids).astype(basetypes.NODE_ID) + self.sink_ids = np.atleast_1d(sink_ids).astype(basetypes.NODE_ID) + self.bbox_offset = np.atleast_1d(bbox_offset).astype(basetypes.COORDINATES) + self.path_augment = path_augment + self.disallow_isolating_cut = disallow_isolating_cut + if np.any(np.in1d(self.sink_ids, self.source_ids)): + raise PreconditionError( + "Supervoxels exist in both sink and source, " + "try placing the points further apart." + ) + + ids = np.concatenate([self.source_ids, self.sink_ids]) + layers = self.cg.get_chunk_layers(ids) + assert np.sum(layers) == layers.size, "IDs must be supervoxels." 
+ + def _update_root_ids(self) -> np.ndarray: + sink_and_source_ids = np.concatenate((self.source_ids, self.sink_ids)) + root_ids = np.unique( + self.cg.get_roots( + sink_and_source_ids, assert_roots=True, time_stamp=self.parent_ts + ) + ) + if len(root_ids) > 1: + raise PreconditionError("Supervoxels must belong to the same segment.") + return root_ids + + def _apply( + self, *, operation_id, timestamp + ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: + # Verify that sink and source are from the same root object + root_ids = set( + self.cg.get_roots( + np.concatenate([self.source_ids, self.sink_ids]), + assert_roots=True, + time_stamp=self.parent_ts, + ) + ) + if len(root_ids) > 1: + raise PreconditionError("Supervoxels must belong to the same object.") + + bbox = get_bbox( + self.source_coords, + self.sink_coords, + self.cg.meta.split_bounding_offset, + ) + with TimeIt("get_subgraph", self.cg.graph_id, operation_id): + l2id_agglomeration_d, edges = self.cg.get_subgraph( + root_ids.pop(), bbox=bbox, bbox_is_coordinate=True + ) + + edges = reduce(lambda x, y: x + y, edges, Edges([], [])) + supervoxels = np.concatenate( + [agg.supervoxels for agg in l2id_agglomeration_d.values()] + ) + mask0 = np.in1d(edges.node_ids1, supervoxels) + mask1 = np.in1d(edges.node_ids2, supervoxels) + edges = edges[mask0 & mask1] + if len(edges) == 0: + raise PreconditionError("No local edges found.") + + with TimeIt("multicut", self.cg.graph_id, operation_id): + self.removed_edges = run_multicut( + edges, + self.source_ids, + self.sink_ids, + path_augment=self.path_augment, + disallow_isolating_cut=self.disallow_isolating_cut, + ) + if not self.removed_edges.size: + raise PostconditionError("Mincut could not find any edges to remove.") + + with TimeIt("remove_edges", self.cg.graph_id, operation_id): + return edits.remove_edges( + self.cg, + operation_id=operation_id, + atomic_edges=self.removed_edges, + l2id_agglomeration_d=l2id_agglomeration_d, + time_stamp=timestamp, + parent_ts=self.parent_ts, + ) + + def _create_log_record( + self, + *, + operation_id: np.uint64, + timestamp: datetime, + operation_ts: datetime, + new_root_ids: Sequence[np.uint64], + status: int = 1, + exception: str = "", + ) -> "bigtable.row.Row": + val_dict = { + attributes.OperationLogs.UserID: self.user_id, + attributes.OperationLogs.RootID: new_root_ids, + attributes.OperationLogs.SourceCoordinate: self.source_coords, + attributes.OperationLogs.SinkCoordinate: self.sink_coords, + attributes.OperationLogs.SourceID: self.source_ids, + attributes.OperationLogs.SinkID: self.sink_ids, + attributes.OperationLogs.BoundingBoxOffset: self.bbox_offset, + attributes.OperationLogs.RemovedEdge: self.removed_edges, + attributes.OperationLogs.Status: status, + attributes.OperationLogs.OperationException: exception, + attributes.OperationLogs.OperationTimeStamp: operation_ts, + } + return self.cg.client.mutate_row( + serializers.serialize_uint64(operation_id), val_dict, timestamp + ) + + def invert(self) -> "MergeOperation": + return MergeOperation( + self.cg, + user_id=self.user_id, + added_edges=self.removed_edges, + source_coords=self.source_coords, + sink_coords=self.sink_coords, + ) + + +class RedoOperation(GraphEditOperation): + """ + RedoOperation: Used to apply a previous graph edit operation. In contrast to a + "coincidental" redo (e.g. merging an edge added by a previous merge operation), a + RedoOperation is linked to an earlier operation ID to enable its correct repetition. + Acts as counterpart to UndoOperation. 
+ + NOTE: Avoid instantiating a RedoOperation directly, if possible. The class method + GraphEditOperation.redo_operation() is in general preferred as it will correctly + unroll Undo/Redo chains. + + :param cg: The "ChunkedGraph" object + :type cg: "ChunkedGraph" + :param user_id: User ID that will be assigned to this operation + :type user_id: str + :param superseded_operation_id: Operation ID to be redone + :type superseded_operation_id: np.uint64 + :param multicut_as_split: If true, don't recalculate MultiCutOperation, just + use the resulting removed edges and generate SplitOperation instead (faster). + :type multicut_as_split: bool + """ + + __slots__ = [ + "superseded_operation_id", + "superseded_operation", + "added_edges", + "removed_edges", + "operation_status", + ] + + def __init__( + self, + cg: "ChunkedGraph", + *, + user_id: str, + superseded_operation_id: np.uint64, + multicut_as_split: bool, + ) -> None: + super().__init__(cg, user_id=user_id) + log_record, _ = cg.client.read_log_entry(superseded_operation_id) + log_record_type = GraphEditOperation.get_log_record_type(log_record) + if log_record_type in (RedoOperation, UndoOperation): + raise ValueError( + ( + f"RedoOperation received {log_record_type.__name__} as target operation, " + "which is not allowed. Use GraphEditOperation.redo_operation() instead." + ) + ) + + self.superseded_operation_id = superseded_operation_id + self.operation_status = log_record[attributes.OperationLogs.Status] + if self.operation_status != attributes.OperationLogs.StatusCodes.SUCCESS.value: + return + self.superseded_operation = GraphEditOperation.from_log_record( + cg, log_record=log_record, multicut_as_split=multicut_as_split + ) + if hasattr(self.superseded_operation, "added_edges"): + self.added_edges = self.superseded_operation.added_edges + if hasattr(self.superseded_operation, "removed_edges"): + self.removed_edges = self.superseded_operation.removed_edges + + def _update_root_ids(self): + if self.operation_status != attributes.OperationLogs.StatusCodes.SUCCESS.value: + return types.empty_1d + return self.superseded_operation._update_root_ids() + + def _apply( + self, *, operation_id, timestamp + ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: + return self.superseded_operation._apply( + operation_id=operation_id, timestamp=timestamp + ) + + def _create_log_record( + self, + *, + operation_id: np.uint64, + timestamp: datetime, + operation_ts: datetime, + new_root_ids: Sequence[np.uint64], + status: int = 1, + exception: str = "", + ) -> "bigtable.row.Row": + val_dict = { + attributes.OperationLogs.UserID: self.user_id, + attributes.OperationLogs.RedoOperationID: self.superseded_operation_id, + attributes.OperationLogs.RootID: new_root_ids, + attributes.OperationLogs.OperationTimeStamp: operation_ts, + attributes.OperationLogs.Status: status, + attributes.OperationLogs.OperationException: exception, + } + if hasattr(self, "added_edges"): + val_dict[attributes.OperationLogs.AddedEdge] = self.added_edges + if hasattr(self, "removed_edges"): + val_dict[attributes.OperationLogs.RemovedEdge] = self.removed_edges + return self.cg.client.mutate_row( + serializers.serialize_uint64(operation_id), val_dict, timestamp + ) + + def invert(self) -> "GraphEditOperation": + """ + Inverts a RedoOperation. 
Treated as Undoing the original operation + """ + return UndoOperation( + self.cg, + user_id=self.user_id, + superseded_operation_id=self.superseded_operation_id, + multicut_as_split=True, + ) + + def execute( + self, *, operation_id=None, parent_ts=None, override_ts=None + ) -> "GraphEditOperation.Result": + if self.operation_status != attributes.OperationLogs.StatusCodes.SUCCESS.value: + # Don't redo failed operations + return GraphEditOperation.Result( + operation_id=operation_id, + new_root_ids=types.empty_1d, + new_lvl2_ids=types.empty_1d, + ) + return super().execute( + operation_id=operation_id, parent_ts=parent_ts, override_ts=override_ts + ) + + +class UndoOperation(GraphEditOperation): + """ + UndoOperation: Used to apply the inverse of a previous graph edit operation. In contrast + to a "coincidental" undo (e.g. merging an edge previously removed by a split operation), an + UndoOperation is linked to an earlier operation ID to enable its correct reversal. + + NOTE: Avoid instantiating an UndoOperation directly, if possible. The class method + GraphEditOperation.undo_operation() is in general preferred as it will correctly + unroll Undo/Redo chains. + + :param cg: The "ChunkedGraph" object + :type cg: "ChunkedGraph" + :param user_id: User ID that will be assigned to this operation + :type user_id: str + :param superseded_operation_id: Operation ID to be undone + :type superseded_operation_id: np.uint64 + :param multicut_as_split: If true, don't recalculate MultiCutOperation, just + use the resulting removed edges and generate SplitOperation instead (faster). + :type multicut_as_split: bool + """ + + __slots__ = [ + "superseded_operation_id", + "inverse_superseded_operation", + "added_edges", + "removed_edges", + "operation_status", + ] + + def __init__( + self, + cg: "ChunkedGraph", + *, + user_id: str, + superseded_operation_id: np.uint64, + multicut_as_split: bool, + ) -> None: + super().__init__(cg, user_id=user_id) + log_record, _ = cg.client.read_log_entry(superseded_operation_id) + log_record_type = GraphEditOperation.get_log_record_type(log_record) + if log_record_type in (RedoOperation, UndoOperation): + raise ValueError( + ( + f"UndoOperation received {log_record_type.__name__} as target operation, " + "which is not allowed. Use GraphEditOperation.undo_operation() instead." 
+ ) + ) + + self.superseded_operation_id = superseded_operation_id + self.operation_status = log_record[attributes.OperationLogs.Status] + if self.operation_status != attributes.OperationLogs.StatusCodes.SUCCESS.value: + return + superseded_operation = GraphEditOperation.from_log_record( + cg, log_record=log_record, multicut_as_split=multicut_as_split + ) + if log_record_type is MergeOperation: + # account for additional activated edges so merge can be properly undone + from .misc import get_activated_edges + + activated_edges = get_activated_edges(cg, superseded_operation_id) + if len(activated_edges) > 0: + superseded_operation.added_edges = activated_edges + self.inverse_superseded_operation = superseded_operation.invert() + if hasattr(self.inverse_superseded_operation, "added_edges"): + self.added_edges = self.inverse_superseded_operation.added_edges + if hasattr(self.inverse_superseded_operation, "removed_edges"): + self.removed_edges = self.inverse_superseded_operation.removed_edges + + def _update_root_ids(self): + if self.operation_status != attributes.OperationLogs.StatusCodes.SUCCESS.value: + return types.empty_1d + return self.inverse_superseded_operation._update_root_ids() + + def _apply( + self, *, operation_id, timestamp + ) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]: + if isinstance(self.inverse_superseded_operation, MergeOperation): + return edits.add_edges( + self.inverse_superseded_operation.cg, + atomic_edges=self.inverse_superseded_operation.added_edges, + operation_id=operation_id, + time_stamp=timestamp, + parent_ts=self.inverse_superseded_operation.parent_ts, + allow_same_segment_merge=True, + ) + return self.inverse_superseded_operation._apply( + operation_id=operation_id, timestamp=timestamp + ) + + def _create_log_record( + self, + *, + operation_id: np.uint64, + timestamp: datetime, + operation_ts: datetime, + new_root_ids: Sequence[np.uint64], + status: int = 1, + exception: str = "", + ) -> "bigtable.row.Row": + val_dict = { + attributes.OperationLogs.UserID: self.user_id, + attributes.OperationLogs.UndoOperationID: self.superseded_operation_id, + attributes.OperationLogs.RootID: new_root_ids, + attributes.OperationLogs.OperationTimeStamp: operation_ts, + attributes.OperationLogs.Status: status, + attributes.OperationLogs.OperationException: exception, + } + if hasattr(self, "added_edges"): + val_dict[attributes.OperationLogs.AddedEdge] = self.added_edges + if hasattr(self, "removed_edges"): + val_dict[attributes.OperationLogs.RemovedEdge] = self.removed_edges + return self.cg.client.mutate_row( + serializers.serialize_uint64(operation_id), val_dict, timestamp + ) + + def invert(self) -> "GraphEditOperation": + """ + Inverts an UndoOperation. 
Treated as Redoing the original operation
+        """
+        return RedoOperation(
+            self.cg,
+            user_id=self.user_id,
+            superseded_operation_id=self.superseded_operation_id,
+            multicut_as_split=True,
+        )
+
+    def execute(
+        self, *, operation_id=None, parent_ts=None, override_ts=None
+    ) -> "GraphEditOperation.Result":
+        if self.operation_status != attributes.OperationLogs.StatusCodes.SUCCESS.value:
+            # Don't undo failed operations
+            return GraphEditOperation.Result(
+                operation_id=operation_id,
+                new_root_ids=types.empty_1d,
+                new_lvl2_ids=types.empty_1d,
+            )
+        if isinstance(self.inverse_superseded_operation, MergeOperation):
+            # in case we are undoing a partial split (with only one resulting root id)
+            e, a = get_edges_status(
+                self.inverse_superseded_operation.cg,
+                self.inverse_superseded_operation.added_edges,
+            )
+            if np.any(~e):
+                raise PreconditionError("All edges must exist.")
+            if np.all(a):
+                return GraphEditOperation.Result(
+                    operation_id=operation_id,
+                    new_root_ids=types.empty_1d,
+                    new_lvl2_ids=types.empty_1d,
+                )
+        if isinstance(self.inverse_superseded_operation, SplitOperation):
+            e, a = get_edges_status(
+                self.inverse_superseded_operation.cg,
+                self.inverse_superseded_operation.removed_edges,
+            )
+            if np.any(~e):
+                raise PreconditionError("All edges must exist.")
+            if np.all(~a):
+                return GraphEditOperation.Result(
+                    operation_id=operation_id,
+                    new_root_ids=types.empty_1d,
+                    new_lvl2_ids=types.empty_1d,
+                )
+        return super().execute(
+            operation_id=operation_id, parent_ts=parent_ts, override_ts=override_ts
+        )
diff --git a/pychunkedgraph/graph/segmenthistory.py b/pychunkedgraph/graph/segmenthistory.py
new file mode 100644
index 000000000..30f42d15b
--- /dev/null
+++ b/pychunkedgraph/graph/segmenthistory.py
@@ -0,0 +1,449 @@
+import collections
+from datetime import datetime
+from typing import Iterable
+
+import numpy as np
+import fastremap
+from networkx.algorithms.dag import ancestors as nx_ancestors
+
+from .attributes import OperationLogs
+from .utils import basetypes
+
+
+class SegmentHistory:
+    def __init__(
+        self,
+        cg,
+        root_ids,
+        timestamp_past: datetime = None,
+        timestamp_future: datetime = None,
+    ):
+        self.cg = cg
+        if isinstance(root_ids, Iterable):
+            self.root_ids = np.array(root_ids)
+        else:
+            self.root_ids = np.array([root_ids])
+
+        for root_id in self.root_ids:
+            assert cg.is_root(root_id), f"{root_id} is not a root"
+
+        self.timestamp_past = cg.get_earliest_timestamp()
+        if timestamp_past is not None:
+            self.timestamp_past = timestamp_past
+
+        self.timestamp_future = datetime.utcnow()
+        if timestamp_future is not None:
+            self.timestamp_future = timestamp_future
+
+        self._lineage_graph = None
+        self._operation_id_root_id_dict = None
+        self._root_id_operation_id_dict = None
+        self._root_id_timestamp_dict = None
+        self._log_rows_cache = None
+        self._tabular_changelogs = None
+
+    @property
+    def lineage_graph(self):
+        from .lineage import lineage_graph
+
+        if self._lineage_graph is None:
+            self._lineage_graph = lineage_graph(
+                self.cg, self.root_ids, self.timestamp_past, self.timestamp_future
+            )
+        return self._lineage_graph
+
+    @property
+    def root_id_operation_id_dict(self):
+        if self._root_id_operation_id_dict is None:
+            self._root_id_operation_id_dict = dict(
+                self.lineage_graph.nodes.data("operation_id", default=0)
+            )
+        return self._root_id_operation_id_dict
+
+    @property
+    def root_id_timestamp_dict(self):
+        if self._root_id_timestamp_dict is None:
+            self._root_id_timestamp_dict = dict(
+                self.lineage_graph.nodes.data("timestamp", default=0)
+            )
+        return
self._root_id_timestamp_dict + + @property + def operation_id_root_id_dict(self): + if self._operation_id_root_id_dict is None: + self._operation_id_root_id_dict = collections.defaultdict(list) + for root_id, operation_id in self.root_id_operation_id_dict.items(): + self._operation_id_root_id_dict[operation_id].append(root_id) + return self._operation_id_root_id_dict + + @property + def operation_ids(self): + return np.array(list(self.operation_id_root_id_dict.keys())) + + @property + def _log_rows(self): + if self._log_rows_cache is None: + self._log_rows_cache = self.cg.client.read_log_entries(self.operation_ids) + return self._log_rows_cache + + @property + def tabular_changelogs(self): + if self._tabular_changelogs is None: + self._tabular_changelogs = self._build_tabular_changelogs() + return self._tabular_changelogs + + @property + def tabular_changelogs_filtered(self): + filtered_tabular_changelogs = {} + for root_id in self.root_ids: + filtered_tabular_changelogs[root_id] = self.tabular_changelog( + root_id=root_id, filtered=True + ) + return filtered_tabular_changelogs + + def collect_edited_sv_ids(self, root_id=None): + if root_id is None: + operation_ids = self.past_operation_ids() + else: + assert root_id in self.root_ids + operation_ids = self.past_operation_ids(root_id=root_id) + + edited_sv_ids = [] + for operation_id in operation_ids: + edited_sv_ids.extend(self.log_entry(operation_id).edges_failsafe) + + if len(edited_sv_ids) > 0: + return fastremap.unique(np.array(edited_sv_ids)) + else: + return np.empty((0), dtype=np.uint64) + + def _build_tabular_changelogs(self): + from pandas import DataFrame + + tabular_changelogs = {} + all_user_ids = [] + root_lookup = lambda sv_ids, ts: dict( + zip( + sv_ids, + self.cg.get_roots(sv_ids, time_stamp=ts), + ) + ).get + + earliest_ts = self.cg.get_earliest_timestamp() + root_ts_d = dict( + zip( + self.root_ids, + self.cg.get_node_timestamps(self.root_ids, return_numpy=False), + ) + ) + for root_id in self.root_ids: + root_ts = root_ts_d[root_id] + edited_sv_ids = self.collect_edited_sv_ids(root_id=root_id) + current_root_id_lookup_vec = np.vectorize( + root_lookup(edited_sv_ids, root_ts) + ) + original_root_id_lookup_vec = np.vectorize( + root_lookup(edited_sv_ids, earliest_ts) + ) + + is_merge_list = [] + is_in_neuron_list = [] + is_relevant_list = [] + timestamp_list = [] + user_list = [] + before_root_ids_list = [] + after_root_ids_list = [] + + operation_ids = self.past_operation_ids(root_id=root_id) + sorted_operation_ids = np.sort(operation_ids) + for operation_id in sorted_operation_ids: + entry = self.log_entry(operation_id) + is_merge_list.append(entry.is_merge) + timestamp_list.append(entry.timestamp) + user_list.append(entry.user_id) + + sv_ids_original_root = original_root_id_lookup_vec(entry.edges_failsafe) + sv_ids_current_root = current_root_id_lookup_vec(entry.edges_failsafe) + before_ids = list(self.operation_id_root_id_dict[operation_id]) + after_root_ids_list.append( + list(self.lineage_graph.neighbors(before_ids[0])) + ) + before_root_ids_list.append(before_ids) + if entry.is_merge: + is_relevant_list.append(len(np.unique(sv_ids_original_root)) != 1) + is_in_neuron_list.append(np.all(sv_ids_current_root == root_id)) + else: + is_relevant_list.append(len(np.unique(sv_ids_current_root)) != 1) + is_in_neuron_list.append(np.any(sv_ids_current_root == root_id)) + + all_user_ids.extend(user_list) + tabular_changelogs[root_id] = DataFrame.from_dict( + { + "operation_id": sorted_operation_ids, + "timestamp": 
timestamp_list, + "user_id": user_list, + "before_root_ids": before_root_ids_list, + "after_root_ids": after_root_ids_list, + "is_merge": is_merge_list, + "in_neuron": is_in_neuron_list, + "is_relevant": is_relevant_list, + } + ) + return tabular_changelogs + + def log_entry(self, operation_id): + ts = self._log_rows[operation_id]["timestamp"] + return LogEntry(self._log_rows[operation_id], timestamp=ts) + + def change_log_summary(self, root_id=None, filtered=False): + if root_id is None: + root_ids = self.root_ids + else: + assert root_id in self.root_ids + root_ids = [root_id] + + for root_id in root_ids: + tabular_changelog = self.tabular_changelog(root_id, filtered=filtered) + + user_ids = np.array(tabular_changelog[["user_id"]]).reshape(-1) + u_user_ids = np.unique(user_ids) + + n_splits = 0 + n_mergers = 0 + n_edits = 0 + user_dict = collections.defaultdict(collections.Counter) + for user_id in u_user_ids: + m = user_ids == user_id + n_user_edits = np.sum(m) + n_user_mergers = int(np.sum(tabular_changelog[["is_merge"]][m])) + n_user_splits = n_user_edits - n_user_mergers + + user_dict[user_id]["n_splits"] += n_user_splits + user_dict[user_id]["n_mergers"] += n_user_mergers + n_splits += n_user_splits + n_mergers += n_user_mergers + n_edits += n_user_edits + + before_col = list(tabular_changelog["before_root_ids"]) + if len(before_col) == 0: + past_ids = np.empty((0), dtype=basetypes.NODE_ID) + else: + past_ids = np.concatenate(before_col, dtype=basetypes.NODE_ID) + + operation_ids = np.array( + tabular_changelog["operation_id"], dtype=basetypes.NODE_ID + ) + + return { + "n_splits": n_splits, + "n_mergers": n_mergers, + "user_info": user_dict, + "operations_ids": operation_ids, + "past_ids": past_ids, + } + + def merge_log(self, root_id=None, correct_for_wrong_coord_type=True): + if root_id is None: + root_ids = self.root_ids + else: + assert root_id in self.root_ids + root_ids = [root_id] + + added_edges = [] + added_edge_coords = [] + + for root_id in root_ids: + for operation_id in self.past_operation_ids(root_id=root_id): + log_entry = self.log_entry(operation_id) + if not log_entry.is_merge: + continue + + added_edges.append(log_entry.added_edges) + coords = log_entry.coordinates + if correct_for_wrong_coord_type: + # A little hack because we got the datatype wrong... 
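+                    # (The serialized coordinate bytes are reinterpreted as float64 arrays via
+                    #  np.frombuffer and then scaled by the dataset resolution.)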
+ coords = [np.frombuffer(coords[0]), np.frombuffer(coords[1])] + coords *= self.cg.meta.resolution + added_edge_coords.append(coords) + return {"merge_edges": added_edges, "merge_edge_coords": added_edge_coords} + + def past_operation_ids(self, root_id=None): + if root_id is None: + root_ids = self.root_ids + else: + assert root_id in self.root_ids + root_ids = [root_id] + + ancs = [] + for root_id in root_ids: + ancs.extend(nx_ancestors(self.lineage_graph, root_id)) + + if len(ancs) == 0: + return np.array([], dtype=int) + + ancs = fastremap.unique(np.array(ancs, dtype=np.uint64)) + operation_ids = [] + for anc in ancs: + operation_ids.append(self.root_id_operation_id_dict.get(anc, 0)) + + operation_ids = np.array(operation_ids) + operation_ids = fastremap.unique(operation_ids) + operation_ids = operation_ids[operation_ids != 0] + return operation_ids + + def tabular_changelog(self, root_id=None, filtered=False): + if len(self.root_ids) == 1: + root_id = self.root_ids[0] + else: + assert root_id is not None + + tabular_changelog = self.tabular_changelogs[root_id].copy() + if filtered: + in_neuron = np.array(tabular_changelog[["in_neuron"]]) + is_relevant = np.array(tabular_changelog[["is_relevant"]]) + inclusion_mask = np.logical_and(in_neuron, is_relevant).reshape(-1) + + tabular_changelog = tabular_changelog[inclusion_mask] + tabular_changelog = tabular_changelog.drop("in_neuron", axis=1) + tabular_changelog = tabular_changelog.drop("is_relevant", axis=1) + return tabular_changelog + + def last_edit_timestamp(self, root_id=None): + assert root_id in self.root_ids + return self.root_id_timestamp_dict[root_id] + + def past_future_id_mapping(self, root_id=None): + from networkx.algorithms.dag import descendants as nx_descendants + + if root_id is None: + root_ids = self.root_ids + else: + assert root_id in self.root_ids + root_ids = [root_id] + + in_degree_dict = dict(self.lineage_graph.in_degree) + out_degree_dict = dict(self.lineage_graph.out_degree) + in_degree_dict_vec = np.vectorize(in_degree_dict.get) + out_degree_dict_vec = np.vectorize(out_degree_dict.get) + past_id_mapping = {} + future_id_mapping = {} + for root_id in root_ids: + ancestors = np.array(list(nx_ancestors(self.lineage_graph, root_id))) + if len(ancestors) == 0: + past_id_mapping[int(root_id)] = [root_id] + else: + anc_in_degrees = in_degree_dict_vec(ancestors) + past_id_mapping[int(root_id)] = ancestors[anc_in_degrees == 0] + + past_ids = fastremap.unique(np.concatenate(list(past_id_mapping.values()))) + for past_id in past_ids: + descendants = np.array( + list(nx_descendants(self.lineage_graph, past_id)) + [past_id], + dtype=np.uint64, + ) + if len(descendants) == 0: + future_id_mapping[past_id] = past_id + continue + + out_degrees = out_degree_dict_vec(descendants) + if 2 in out_degrees or np.sum(out_degrees == 0) > 1: + continue + + single_degree_descendants = descendants[ + out_degree_dict_vec(descendants) == 1 + ] + if len(single_degree_descendants) == 0: + future_id_mapping[past_id] = descendants[out_degrees == 0][0] + continue + + partner_in_degrees = in_degree_dict_vec( + [ + list(self.lineage_graph.neighbors(d))[0] + for d in single_degree_descendants + ] + ) + if 1 in partner_in_degrees: + continue + future_id_mapping[past_id] = descendants[out_degrees == 0][0] + return past_id_mapping, future_id_mapping + + +class LogEntry: + def __init__(self, row, timestamp): + self.row = row + self.timestamp = timestamp + + @property + def is_merge(self): + return OperationLogs.AddedEdge in self.row + + @property + def 
user_id(self): + return self.row[OperationLogs.UserID] + + @property + def log_type(self): + return "merge" if self.is_merge else "split" + + @property + def root_ids(self): + return self.row[OperationLogs.RootID] + + @property + def edges_failsafe(self): + try: + return np.array(self.sink_source_ids) + except: + if self.is_merge: + return np.array(self.added_edges).flatten() + if not self.is_merge: + return np.array(self.removed_edges).flatten() + + @property + def sink_source_ids(self): + return np.concatenate( + [ + self.row[OperationLogs.SinkID], + self.row[OperationLogs.SourceID], + ] + ) + + @property + def added_edges(self): + assert self.is_merge, "Not a merge operation." + return self.row[OperationLogs.AddedEdge] + + @property + def removed_edges(self): + assert not self.is_merge, "Not a split operation." + return self.row[OperationLogs.RemovedEdge] + + @property + def coordinates(self): + return np.array( + [ + self.row[OperationLogs.SourceCoordinate], + self.row[OperationLogs.SinkCoordinate], + ] + ) + + def __iter__(self): + attrs = [self.user_id, self.log_type, self.root_ids, self.timestamp] + for attr in attrs: + yield attr + + def __str__(self): + return ",".join([str(x) for x in self]) + + +def get_all_log_entries(cg): + log_entries = [] + log_rows = cg.client.read_log_entries() + for operation_id in range(cg.client.get_max_operation_id()): + try: + log_entries.append( + LogEntry(log_rows[operation_id], log_rows[operation_id]["timestamp"]) + ) + except KeyError: + continue + return log_entries diff --git a/pychunkedgraph/graph/subgraph.py b/pychunkedgraph/graph/subgraph.py new file mode 100644 index 000000000..ab2593175 --- /dev/null +++ b/pychunkedgraph/graph/subgraph.py @@ -0,0 +1,246 @@ +from typing import List +from typing import Dict +from typing import Tuple +from typing import Union +from typing import Iterable +from typing import Sequence +from typing import Optional + +import numpy as np + +from .edges import Edges +from .chunks.utils import normalize_bounding_box + + +class SubgraphProgress: + """ + Helper class to keep track of node relationships + while calling cg.get_subgraph(node_ids) + """ + + def __init__(self, meta, node_ids, return_layers, serializable): + from collections import defaultdict + + self.meta = meta + self.node_ids = node_ids + self.return_layers = return_layers + self.serializable = serializable + + self.node_to_subgraph = defaultdict(lambda: defaultdict(list)) + # "Frontier" of nodes that cg.get_children will be called on + self.cur_nodes = np.array(list(node_ids), dtype=np.uint64) + # Mapping of current frontier to self.node_ids + self.cur_nodes_to_original_nodes = dict( + zip(self.cur_nodes, self.cur_nodes) + ) + self.stop_layer = max(1, min(return_layers)) + self.create_initial_node_to_subgraph() + + def done_processing(self): + return self.cur_nodes is None or len(self.cur_nodes) == 0 + + def create_initial_node_to_subgraph(self): + """ + Create initial subgraph. We will incrementally populate after processing + each batch of children, and return it when there are no more to process. 
+ """ + from .chunks.utils import get_chunk_layer + + for node_id in self.cur_nodes: + node_key = self.get_dict_key(node_id) + node_layer = get_chunk_layer(self.meta, node_id) + if node_layer in self.return_layers: + self.node_to_subgraph[node_key][node_layer].append([node_id]) + + def get_dict_key(self, node_id): + if self.serializable: + return str(node_id) + return node_id + + def process_batch_of_children(self, cur_nodes_children): + """ + Given children of self.cur_nodes, update subgraph and + produce next frontier (if any). + """ + from .chunks.utils import get_chunk_layers + + next_nodes_to_process = [] + next_nodes_to_original_nodes_keys = [] + next_nodes_to_original_nodes_values = [] + for cur_node, children in cur_nodes_children.items(): + children_layers = get_chunk_layers(self.meta, children) + continue_mask = children_layers > self.stop_layer + continue_children = children[continue_mask] + original_id = self.cur_nodes_to_original_nodes[np.uint64(cur_node)] + if len(continue_children) > 0: + # These nodes will be in next frontier + next_nodes_to_process.append(continue_children) + next_nodes_to_original_nodes_keys.append(continue_children) + next_nodes_to_original_nodes_values.append( + [original_id] * len(continue_children) + ) + for return_layer in self.return_layers: + # Update subgraph for each return_layer + children_at_layer = children[children_layers == return_layer] + if len(children_at_layer) > 0: + self.node_to_subgraph[self.get_dict_key(original_id)][ + return_layer + ].append(children_at_layer) + + if len(next_nodes_to_process) == 0: + self.cur_nodes = None + # We are done, so we can np.concatenate/flatten each entry in node_to_subgraph + self.flatten_subgraph() + else: + self.cur_nodes = np.concatenate(next_nodes_to_process) + self.cur_nodes_to_original_nodes = dict( + zip( + np.concatenate(next_nodes_to_original_nodes_keys), + np.concatenate(next_nodes_to_original_nodes_values), + ) + ) + + def flatten_subgraph(self): + from .types import empty_1d + + # Flatten each entry in node_to_subgraph before returning + for node_id in self.node_ids: + for return_layer in self.return_layers: + node_key = self.get_dict_key(node_id) + children_at_layer = self.node_to_subgraph[node_key][ + return_layer + ] + if len(children_at_layer) > 0: + self.node_to_subgraph[node_key][ + return_layer + ] = np.concatenate(children_at_layer) + else: + self.node_to_subgraph[node_key][return_layer] = empty_1d + + +def get_subgraph_nodes( + cg, + node_id_or_ids: Union[np.uint64, Iterable], + bbox: Optional[Sequence[Sequence[int]]] = None, + bbox_is_coordinate: bool = False, + return_layers: List = [2], + serializable: bool = False, + return_flattened: bool = False +) -> Tuple[Dict, Dict, Edges]: + single = False + node_ids = node_id_or_ids + bbox = normalize_bounding_box(cg.meta, bbox, bbox_is_coordinate) + if isinstance(node_id_or_ids, np.uint64) or isinstance(node_id_or_ids, int): + single = True + node_ids = [node_id_or_ids] + layer_nodes_d = _get_subgraph_multiple_nodes( + cg, + node_ids=node_ids, + bounding_box=bbox, + return_layers=return_layers, + serializable=serializable, + return_flattened=return_flattened + ) + if single: + if serializable: + return layer_nodes_d[str(node_id_or_ids)] + return layer_nodes_d[node_id_or_ids] + return layer_nodes_d + + +def get_subgraph_edges_and_leaves( + cg, + node_id_or_ids: Union[np.uint64, Iterable], + bbox: Optional[Sequence[Sequence[int]]] = None, + bbox_is_coordinate: bool = False, + edges_only: bool = False, + leaves_only: bool = False, +) -> 
Tuple[Dict, Dict, Edges]: + """Get the edges and/or leaves of the specified node_ids within the specified bounding box.""" + from .types import empty_1d + + node_ids = node_id_or_ids + bbox = normalize_bounding_box(cg.meta, bbox, bbox_is_coordinate) + if isinstance(node_id_or_ids, np.uint64) or isinstance(node_id_or_ids, int): + node_ids = [node_id_or_ids] + layer_nodes_d = _get_subgraph_multiple_nodes( + cg, node_ids, bbox, return_layers=[2], return_flattened=True + ) + level2_ids = [empty_1d] + for node_id in node_ids: + level2_ids.append(layer_nodes_d[node_id]) + level2_ids = np.concatenate(level2_ids) + if leaves_only: + return cg.get_children(level2_ids, flatten=True) + if edges_only: + return cg.get_l2_agglomerations(level2_ids, edges_only=True) + return cg.get_l2_agglomerations(level2_ids) + + +def _get_subgraph_multiple_nodes( + cg, + node_ids: Iterable[np.uint64], + bounding_box: Optional[Sequence[Sequence[int]]], + return_layers: Sequence[int], + serializable: bool = False, + return_flattened: bool = False +): + from collections import ChainMap + from multiwrapper.multiprocessing_utils import n_cpus + from multiwrapper.multiprocessing_utils import multithread_func + + from .utils.generic import mask_nodes_by_bounding_box + + assert len(return_layers) > 0 + + def _get_dict_key(raw_key): + if serializable: + return str(raw_key) + return raw_key + + def _get_subgraph_multiple_nodes_threaded( + node_ids_batch: Iterable[np.uint64], + ) -> List[np.uint64]: + children = cg.get_children(node_ids_batch) + if bounding_box is not None: + filtered_children = {} + for node_id, nodes_children in children.items(): + if cg.get_chunk_layer(node_id) == 2: + # All children will be in same chunk so no need to check + filtered_children[_get_dict_key(node_id)] = nodes_children + elif len(nodes_children) > 0: + bound_check_mask = mask_nodes_by_bounding_box( + cg.meta, nodes_children, bounding_box + ) + filtered_children[_get_dict_key(node_id)] = nodes_children[ + bound_check_mask + ] + return filtered_children + return children + + if bounding_box is not None: + bounding_box = np.array(bounding_box) + + subgraph = SubgraphProgress(cg.meta, node_ids, return_layers, serializable) + while not subgraph.done_processing(): + this_n_threads = min( + [int(len(subgraph.cur_nodes) // 50000) + 1, n_cpus] + ) + cur_nodes_child_maps = multithread_func( + _get_subgraph_multiple_nodes_threaded, + np.array_split(subgraph.cur_nodes, this_n_threads), + n_threads=this_n_threads, + debug=this_n_threads == 1, + ) + cur_nodes_children = dict(ChainMap(*cur_nodes_child_maps)) + subgraph.process_batch_of_children(cur_nodes_children) + + if return_flattened and len(return_layers) == 1: + for node_id in node_ids: + subgraph.node_to_subgraph[ + _get_dict_key(node_id) + ] = subgraph.node_to_subgraph[_get_dict_key(node_id)][ + return_layers[0] + ] + + return subgraph.node_to_subgraph \ No newline at end of file diff --git a/pychunkedgraph/graph/types.py b/pychunkedgraph/graph/types.py new file mode 100644 index 000000000..9a551f35c --- /dev/null +++ b/pychunkedgraph/graph/types.py @@ -0,0 +1,42 @@ +from typing import Dict +from typing import Iterable +from collections import namedtuple + +import numpy as np + +from .utils import basetypes + +empty_1d = np.empty(0, dtype=basetypes.NODE_ID) +empty_2d = np.empty((0, 2), dtype=basetypes.NODE_ID) + + +""" +An Agglomeration is syntactic sugar for representing +a level 2 ID and it's supervoxels and edges. +`in_edges` + edges between supervoxels belonging to the agglomeration. 
+`out_edges` + edges between supervoxels of agglomeration + and neighboring agglomeration. +`cross_edges_d` + dict of cross edges {layer: cross_edges_relevant_on_that_layer} +""" +_agglomeration_fields = ( + "node_id", + "supervoxels", + "in_edges", + "out_edges", + "cross_edges", + "cross_edges_d", +) +_agglomeration_defaults = ( + None, + empty_1d.copy(), + empty_2d.copy(), + empty_2d.copy(), + empty_2d.copy(), + {}, +) +Agglomeration = namedtuple( + "Agglomeration", _agglomeration_fields, defaults=_agglomeration_defaults +) diff --git a/pychunkedgraph/exporting/__init__.py b/pychunkedgraph/graph/utils/__init__.py similarity index 100% rename from pychunkedgraph/exporting/__init__.py rename to pychunkedgraph/graph/utils/__init__.py diff --git a/pychunkedgraph/backend/utils/basetypes.py b/pychunkedgraph/graph/utils/basetypes.py similarity index 85% rename from pychunkedgraph/backend/utils/basetypes.py rename to pychunkedgraph/graph/utils/basetypes.py index b191f748e..e55324e6a 100644 --- a/pychunkedgraph/backend/utils/basetypes.py +++ b/pychunkedgraph/graph/utils/basetypes.py @@ -1,7 +1,7 @@ import numpy as np -CHUNK_ID = SEGMENT_ID = NODE_ID = np.dtype('uint64').newbyteorder('L') +CHUNK_ID = SEGMENT_ID = NODE_ID = OPERATION_ID = np.dtype('uint64').newbyteorder('L') EDGE_AFFINITY = np.dtype('float32').newbyteorder('L') EDGE_AREA = np.dtype('uint64').newbyteorder('L') diff --git a/pychunkedgraph/graph/utils/flatgraph.py b/pychunkedgraph/graph/utils/flatgraph.py new file mode 100644 index 000000000..df469d728 --- /dev/null +++ b/pychunkedgraph/graph/utils/flatgraph.py @@ -0,0 +1,221 @@ +import fastremap +import numpy as np +from itertools import combinations, chain +from graph_tool import Graph, GraphView +from graph_tool import topology, search + + +def build_gt_graph( + edges, weights=None, is_directed=True, make_directed=False, hashed=False +): + """Builds a graph_tool graph + :param edges: n x 2 numpy array + :param weights: numpy array of length n + :param is_directed: bool + :param make_directed: bool + :param hashed: bool + :return: graph, capacities + """ + edges = np.array(edges, np.uint64) + if weights is not None: + assert len(weights) == len(edges) + weights = np.array(weights) + + unique_ids, edges = np.unique(edges, return_inverse=True) + edges = edges.reshape(-1, 2) + + edges = np.array(edges) + + if make_directed: + is_directed = True + edges = np.concatenate([edges, edges[:, [1, 0]]]) + + if weights is not None: + weights = np.concatenate([weights, weights]) + + weighted_graph = Graph(directed=is_directed) + weighted_graph.add_edge_list(edge_list=edges, hashed=hashed) + + if weights is not None: + cap = weighted_graph.new_edge_property("float", vals=weights) + else: + cap = None + return weighted_graph, cap, edges, unique_ids + + +def remap_ids_from_graph(graph_ids, unique_ids): + return unique_ids[graph_ids] + + +def connected_components(graph): + """Computes connected components of graph_tool graph + :param graph: graph_tool.Graph + :return: np.array of len == number of nodes + """ + assert isinstance(graph, Graph) + + cc_labels = topology.label_components(graph)[0].a + + if len(cc_labels) == 0: + return [] + + idx_sort = np.argsort(cc_labels) + _, idx_start = np.unique(cc_labels[idx_sort], return_index=True) + + return np.split(idx_sort, idx_start[1:]) + + +def team_paths_all_to_all(graph, capacity, team_vertex_ids): + dprop = capacity.copy() + # Use inverse affinity as the distance between vertices. 
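+    # High-affinity edges therefore get short distances, so the shortest-path queries
+    # below prefer strongly connected routes; eps guards against division by zero on
+    # zero-affinity edges.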
+ dprop.a = 1 / (dprop.a + np.finfo(np.float64).eps) + + paths_v = [] + paths_e = [] + path_affinities = [] + for i1, i2 in combinations(team_vertex_ids, 2): + v_list, e_list = topology.shortest_path( + graph, + source=graph.vertex(i1), + target=graph.vertex(i2), + weights=dprop, + ) + paths_v.append(v_list) + paths_e.append(e_list) + path_affinities.append(np.sum([dprop[e] for e in e_list])) + + return paths_v, paths_e, path_affinities + + +def neighboring_edges(graph, vertex_id): + """Returns vertex and edge lists of a seed vertex, in the same format as team_paths_all_to_all.""" + add_v = [] + add_e = [] + v0 = graph.vertex(vertex_id) + neibs = v0.out_neighbors() + for v in neibs: + add_v.append(v) + add_e.append(graph.edge(v, v0)) + return [add_v], [add_e], [1] + + +def intersect_nodes(paths_v_s, paths_v_y): + inds_s = np.unique([int(v) for v in chain.from_iterable(paths_v_s)]) + inds_y = np.unique([int(v) for v in chain.from_iterable(paths_v_y)]) + return np.intersect1d(inds_s, inds_y) + + +def harmonic_mean_paths(x): + return np.power(np.product(x), 1 / len(x)) + + +def compute_filtered_paths( + graph, + capacity, + team_vertex_ids, + intersect_vertices, +): + """Make a filtered GraphView that excludes intersect vertices and recompute shortest paths""" + intersection_filter = np.full(graph.num_vertices(), True) + intersection_filter[intersect_vertices] = False + vfilt = graph.new_vertex_property("bool", vals=intersection_filter) + gfilt = GraphView(graph, vfilt=vfilt) + paths_v, paths_e, path_affinities = team_paths_all_to_all( + gfilt, capacity, team_vertex_ids + ) + + # graph-tool will invalidate the vertex and edge properties if I don't rebase them on the main graph + # before tearing down the GraphView + new_paths_e = [] + for pth in paths_e: + # An empty path means vertices are not connected, which is disallowed + assert len(pth) > 0 + new_path = [] + for e in pth: + new_path.append(graph.edge(int(e.source()), int(e.target()))) + new_paths_e.append(new_path) + + new_paths_v = [] + for pth in paths_v: + new_path = [] + for v in pth: + new_path.append(graph.vertex(int(v))) + new_paths_v.append(new_path) + return new_paths_v, new_paths_e, path_affinities + + +def remove_overlapping_edges(paths_v_s, paths_e_s, paths_v_y, paths_e_y): + """Remove vertices that are in the paths from both teams""" + iverts = intersect_nodes(paths_v_s, paths_v_y) + if len(iverts) == 0: + return paths_e_s, paths_e_y, False + else: + path_e_s_out = [ + [ + e + for e in chain.from_iterable(paths_e_s) + if not np.any(np.isin([int(e.source()), int(e.target())], iverts)) + ] + ] + path_e_y_out = [ + [ + e + for e in chain.from_iterable(paths_e_y) + if not np.any(np.isin([int(e.source()), int(e.target())], iverts)) + ] + ] + return path_e_s_out, path_e_y_out, True + + +def check_connectedness(vertices, edges, expected_number=1): + """Returns True if the augmenting edges still form a single connected component""" + paths_inds = np.unique([int(v) for v in chain.from_iterable(vertices)]) + edge_list_inds = np.array( + [[int(e.source()), int(e.target())] for e in chain.from_iterable(edges)] + ) + + rmap = {v: ii for ii, v in enumerate(paths_inds)} + edge_list_remap = fastremap.remap(edge_list_inds, rmap) + + g2 = Graph(directed=False) + g2.add_vertex(n=len(paths_inds)) + if len(edge_list_remap) > 0: + g2.add_edge_list(np.atleast_2d(edge_list_remap)) + + _, count = topology.label_components(g2) + return len(count) == expected_number + + +def reverse_edge(graph, edge): + """Returns the complementary edge""" + return 
graph.edge(edge.target(), edge.source()) + + +def adjust_affinities(graph, capacity, paths_e, value=np.finfo(np.float32).max): + """Set affinity of a subset of paths to a particular value (typically the largest double).""" + capacity = capacity.copy() + + e_array = np.array( + [(int(e.source()), int(e.target())) for e in chain.from_iterable(paths_e)] + ) + if len(e_array) > 0: + e_array = np.sort(e_array, axis=1) + e_array = np.unique(e_array, axis=0) + e_list = [graph.edge(e[0], e[1]) for e in e_array] + else: + e_list = [] + + for edge in e_list: + capacity[edge] = value + # Capacity is a symmetric directed network + capacity[reverse_edge(graph, edge)] = value + return capacity + + +def flatten_edge_list(paths_e): + return np.unique( + [ + (int(e.source()), int(e.target())) + for e in chain.from_iterable(x for x in paths_e) + ] + ) diff --git a/pychunkedgraph/graph/utils/generic.py b/pychunkedgraph/graph/utils/generic.py new file mode 100644 index 000000000..9a2b6f979 --- /dev/null +++ b/pychunkedgraph/graph/utils/generic.py @@ -0,0 +1,186 @@ +""" +generic helper functions +TODO categorize properly +""" + + +import datetime +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Union +from typing import Sequence +from typing import Tuple +from collections import defaultdict + +import numpy as np +import pandas as pd +import pytz + +from ..chunks import utils as chunk_utils + + +def compute_indices_pandas(data) -> pd.Series: + """Computes indices of all unique entries + Make sure to remap your array to a dense range starting at zero + https://stackoverflow.com/questions/33281957/faster-alternative-to-numpy-where + :param data: np.ndarray + :return: pandas dataframe + """ + d = data.ravel() + f = lambda x: np.unravel_index(x.index, data.shape) + return pd.Series(d).groupby(d).apply(f) + + +def log_n(arr, n): + """Computes log to base n + :param arr: array or float + :param n: int + base + :return: return log_n(arr) + """ + if n == 2: + return np.log2(arr) + elif n == 10: + return np.log10(arr) + else: + return np.log(arr) / np.log(n) + + +def compute_bitmasks(n_layers: int, s_bits_atomic_layer: int = 8) -> Dict[int, int]: + """Computes the bitmasks for each layer. A bitmasks encodes how many bits + are used to store the chunk id in each dimension. The smallest number of + bits needed to encode this information is chosen. The layer id is always + encoded with 8 bits as this information is required a priori. + Currently, encoding of layer 1 is fixed to 8 bits. + :param n_layers: int + :param fan_out: int + :param s_bits_atomic_layer: int + :return: dict + layer -> bits for layer id + """ + bitmask_dict = {} + for i_layer in range(n_layers, 1, -1): + layer_exp = n_layers - i_layer + n_bits_for_layers = max(1, layer_exp) + if i_layer == 2: + if s_bits_atomic_layer < n_bits_for_layers: + err = f"{s_bits_atomic_layer} bits is not enough for encoding." 
+ raise ValueError(err) + n_bits_for_layers = np.max([s_bits_atomic_layer, n_bits_for_layers]) + + n_bits_for_layers = int(n_bits_for_layers) + bitmask_dict[i_layer] = n_bits_for_layers + bitmask_dict[1] = bitmask_dict[2] + return bitmask_dict + + +def get_max_time(): + """Returns the (almost) max time in datetime.datetime + :return: datetime.datetime + """ + return datetime.datetime(9999, 12, 31, 23, 59, 59, 0) + + +def get_min_time(): + """Returns the min time in datetime.datetime + :return: datetime.datetime + """ + return datetime.datetime.strptime("01/01/00 00:00", "%d/%m/%y %H:%M") + + +def time_min(): + """Returns a minimal time stamp that still works with google + :return: datetime.datetime + """ + return datetime.datetime.strptime("01/01/00 00:00", "%d/%m/%y %H:%M") + + +def get_valid_timestamp(timestamp): + if timestamp is None: + timestamp = datetime.datetime.utcnow() + if timestamp.tzinfo is None: + timestamp = pytz.UTC.localize(timestamp) + # Comply to resolution of BigTables TimeRange + return _get_google_compatible_time_stamp(timestamp, round_up=False) + + +def get_bounding_box( + source_coords: Sequence[Sequence[int]], + sink_coords: Sequence[Sequence[int]], + bb_offset: Tuple[int, int, int] = (120, 120, 12), +): + if source_coords is None or sink_coords is None: + return + bb_offset = np.array(list(bb_offset)) + source_coords = np.array(source_coords) + sink_coords = np.array(sink_coords) + + coords = np.concatenate([source_coords, sink_coords]) + bounding_box = [np.min(coords, axis=0), np.max(coords, axis=0)] + bounding_box[0] -= bb_offset + bounding_box[1] += bb_offset + return bounding_box + + +def filter_failed_node_ids(row_ids, segment_ids, max_children_ids): + """filters node ids that were created by failed/in-complete jobs""" + sorting = np.argsort(segment_ids)[::-1] + row_ids = row_ids[sorting] + max_child_ids = np.array(max_children_ids)[sorting] + + counter = defaultdict(int) + max_child_ids_occ_so_far = np.zeros(len(max_child_ids), dtype=int) + for i_row in range(len(max_child_ids)): + max_child_ids_occ_so_far[i_row] = counter[max_child_ids[i_row]] + counter[max_child_ids[i_row]] += 1 + return row_ids[max_child_ids_occ_so_far == 0] + + +def _get_google_compatible_time_stamp( + time_stamp: datetime.datetime, round_up: bool = False +) -> datetime.datetime: + """Makes a datetime.datetime time stamp compatible with googles' services. + Google restricts the accuracy of time stamps to milliseconds. Hence, the + microseconds are cut of. By default, time stamps are rounded to the lower + number. 
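The millisecond rounding described in this docstring is easy to see in isolation. A minimal standalone sketch (not the helper itself), assuming naive datetimes:

```python
import datetime

def round_to_millisecond(ts: datetime.datetime, round_up: bool = False) -> datetime.datetime:
    """Illustration only: clamp a timestamp to millisecond resolution."""
    gap = datetime.timedelta(microseconds=ts.microsecond % 1000)
    if gap == datetime.timedelta(0):
        return ts
    if round_up:
        return ts + datetime.timedelta(microseconds=1000) - gap
    return ts - gap

ts = datetime.datetime(2024, 1, 1, 12, 0, 0, 123456)
print(round_to_millisecond(ts))                 # 2024-01-01 12:00:00.123000
print(round_to_millisecond(ts, round_up=True))  # 2024-01-01 12:00:00.124000
```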
+ :param time_stamp: datetime.datetime + :param round_up: bool + :return: datetime.datetime + """ + micro_s_gap = datetime.timedelta(microseconds=time_stamp.microsecond % 1000) + if micro_s_gap == 0: + return time_stamp + if round_up: + time_stamp += datetime.timedelta(microseconds=1000) - micro_s_gap + else: + time_stamp -= micro_s_gap + return time_stamp + + +def mask_nodes_by_bounding_box( + meta, + nodes: Union[Iterable[np.uint64], np.uint64], + bounding_box: Optional[Sequence[Sequence[int]]] = None, +) -> Iterable[bool]: + if bounding_box is None: + return np.ones(len(nodes), bool) + else: + chunk_coordinates = np.array( + [chunk_utils.get_chunk_coordinates(meta, c) for c in nodes] + ) + layers = chunk_utils.get_chunk_layers(meta, nodes) + adapt_layers = layers - 2 + adapt_layers[adapt_layers < 0] = 0 + fanout = meta.graph_config.FANOUT + bounding_box_layer = ( + bounding_box[None] / (fanout ** adapt_layers)[:, None, None] + ) + bound_check = np.array( + [ + np.all(chunk_coordinates < bounding_box_layer[:, 1], axis=1), + np.all(chunk_coordinates + 1 > bounding_box_layer[:, 0], axis=1), + ] + ).T + + return np.all(bound_check, axis=1) \ No newline at end of file diff --git a/pychunkedgraph/graph/utils/id_helpers.py b/pychunkedgraph/graph/utils/id_helpers.py new file mode 100644 index 000000000..aa486ac84 --- /dev/null +++ b/pychunkedgraph/graph/utils/id_helpers.py @@ -0,0 +1,189 @@ +""" +Utils functions for node and segment IDs. +""" + +from typing import Optional +from typing import Sequence +from typing import Callable +from datetime import datetime + +import numpy as np + +from . import basetypes +from ..meta import ChunkedGraphMeta +from ..chunks import utils as chunk_utils + + +def get_segment_id_limit( + meta: ChunkedGraphMeta, node_or_chunk_id: basetypes.CHUNK_ID +) -> basetypes.SEGMENT_ID: + """Get maximum possible Segment ID for given Node ID or Chunk ID.""" + layer = chunk_utils.get_chunk_layer(meta, node_or_chunk_id) + chunk_offset = 64 - meta.graph_config.LAYER_ID_BITS - 3 * meta.bitmasks[layer] + return np.uint64(2 ** chunk_offset - 1) + + +def get_segment_id( + meta: ChunkedGraphMeta, node_id: basetypes.NODE_ID +) -> basetypes.SEGMENT_ID: + """Extract Segment ID from Node ID.""" + return node_id & get_segment_id_limit(meta, node_id) + + +def get_node_id( + meta: ChunkedGraphMeta, + segment_id: basetypes.SEGMENT_ID, + chunk_id: basetypes.CHUNK_ID = None, + layer: int = None, + x: int = None, + y: int = None, + z: int = None, +) -> basetypes.NODE_ID: + """ + (1) Build Node ID from Segment ID and Chunk ID + (2) Build Node ID from Segment ID, Layer, X, Y and Z components + """ + if chunk_id is not None: + return chunk_id | segment_id + else: + return chunk_utils.get_chunk_id(meta, layer=layer, x=x, y=y, z=z) | segment_id + + +def get_atomic_id_from_coord( + meta: ChunkedGraphMeta, + get_root: callable, + x: int, + y: int, + z: int, + parent_id: np.uint64, + n_tries: int = 5, + time_stamp: Optional[datetime] = None, +) -> np.uint64: + """Determines atomic id given a coordinate.""" + x = int(x / 2 ** meta.data_source.CV_MIP) + y = int(y / 2 ** meta.data_source.CV_MIP) + z = int(z) + + checked = [] + atomic_id = None + root_id = get_root(parent_id, time_stamp=time_stamp) + + for i_try in range(n_tries): + # Define block size -- increase by one each try + x_l = x - (i_try - 1) ** 2 + y_l = y - (i_try - 1) ** 2 + z_l = z - (i_try - 1) ** 2 + + x_h = x + 1 + (i_try - 1) ** 2 + y_h = y + 1 + (i_try - 1) ** 2 + z_h = z + 1 + (i_try - 1) ** 2 + + x_l = 0 if x_l < 0 else x_l + y_l = 0 
if y_l < 0 else y_l + z_l = 0 if z_l < 0 else z_l + + # Get atomic ids from cloudvolume + atomic_id_block = meta.cv[x_l:x_h, y_l:y_h, z_l:z_h] + atomic_ids, atomic_id_count = np.unique(atomic_id_block, return_counts=True) + + # sort by frequency and discard those ids that have been checked + # previously + sorted_atomic_ids = atomic_ids[np.argsort(atomic_id_count)] + sorted_atomic_ids = sorted_atomic_ids[~np.in1d(sorted_atomic_ids, checked)] + + # For each candidate id check whether its root id corresponds to the + # given root id + for candidate_atomic_id in sorted_atomic_ids: + if candidate_atomic_id != 0: + ass_root_id = get_root(candidate_atomic_id, time_stamp=time_stamp) + if ass_root_id == root_id: + # atomic_id is not None will be our indicator that the + # search was successful + atomic_id = candidate_atomic_id + break + else: + checked.append(candidate_atomic_id) + if atomic_id is not None: + break + # Returns None if unsuccessful + return atomic_id + + +def get_atomic_ids_from_coords( + meta: ChunkedGraphMeta, + coordinates: Sequence[Sequence[int]], + parent_id: np.uint64, + parent_id_layer: int, + parent_ts: datetime, + get_roots: Callable, + max_dist_nm: int = 150, +) -> Sequence[np.uint64]: + """Retrieves supervoxel ids for multiple coords. + + :param coordinates: n x 3 np.ndarray of locations in voxel space + :param parent_id: parent id common to all coordinates at any layer + :param max_dist_nm: max distance explored + :return: supervoxel ids; returns None if no solution was found + """ + import fastremap + + if parent_id_layer == 1: + return np.array([parent_id] * len(coordinates), dtype=np.uint64) + + coordinates_nm = coordinates * np.array(meta.resolution) + # Define bounding box to be explored + max_dist_vx = np.ceil(max_dist_nm / meta.resolution).astype(dtype=np.int32) + bbox = np.array( + [ + np.min(coordinates, axis=0) - max_dist_vx, + np.max(coordinates, axis=0) + max_dist_vx + 1, + ] + ) + + local_sv_seg = meta.cv[ + bbox[0, 0] : bbox[1, 0], bbox[0, 1] : bbox[1, 1], bbox[0, 2] : bbox[1, 2] + ].squeeze() + + # limit get_roots calls to the relevant areas of the data + lower_bs = np.floor( + (np.array(coordinates_nm) - max_dist_nm) / np.array(meta.resolution) - bbox[0] + ).astype(np.int32) + upper_bs = np.ceil( + (np.array(coordinates_nm) + max_dist_nm) / np.array(meta.resolution) - bbox[0] + ).astype(np.int32) + local_sv_ids = [] + for lb, ub in zip(lower_bs, upper_bs): + local_sv_ids.extend( + fastremap.unique(local_sv_seg[lb[0] : ub[0], lb[1] : ub[1], lb[2] : ub[2]]) + ) + local_sv_ids = fastremap.unique(np.array(local_sv_ids, dtype=np.uint64)) + local_parent_ids = get_roots( + local_sv_ids, + time_stamp=parent_ts, + stop_layer=parent_id_layer, + fail_to_zero=True + ) + + local_parent_seg = fastremap.remap( + local_sv_seg, + dict(zip(local_sv_ids, local_parent_ids)), + preserve_missing_labels=True, + ) + + parent_id_locs_vx = np.array(np.where(local_parent_seg == parent_id)).T + if len(parent_id_locs_vx) == 0: + return None + + parent_id_locs_nm = (parent_id_locs_vx + bbox[0]) * np.array(meta.resolution) + # find closest supervoxel ids and check that they are closer than the limit + dist_mat = np.sqrt( + np.sum((parent_id_locs_nm[:, None] - coordinates_nm) ** 2, axis=-1) + ) + match_ids = np.argmin(dist_mat, axis=0) + matched_dists = np.array([dist_mat[idx, i] for i, idx in enumerate(match_ids)]) + if np.any(matched_dists > max_dist_nm): + return None + + local_coords = parent_id_locs_vx[match_ids] + matched_sv_ids = [local_sv_seg[tuple(c)] for c in local_coords] + 
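The final matching step above (closest candidate location per query coordinate, rejected beyond `max_dist_nm`) can be illustrated with a tiny standalone example; all ids, locations, and the cutoff below are made up:

```python
import numpy as np

candidate_locs_nm = np.array([[0, 0, 0], [100, 0, 0], [0, 200, 0]], dtype=float)
query_coords_nm = np.array([[10, 0, 0], [0, 190, 0]], dtype=float)
max_dist_nm = 50.0

# (n_candidates, n_queries) pairwise Euclidean distances via broadcasting
dist_mat = np.sqrt(((candidate_locs_nm[:, None] - query_coords_nm) ** 2).sum(axis=-1))
match_idx = np.argmin(dist_mat, axis=0)                   # closest candidate per query
matched_dists = dist_mat[match_idx, np.arange(len(query_coords_nm))]

valid = matched_dists <= max_dist_nm                      # enforce the distance cutoff
print(match_idx, matched_dists, valid)                    # [0 2] [10. 10.] [ True  True]
```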
return matched_sv_ids diff --git a/pychunkedgraph/backend/utils/serializers.py b/pychunkedgraph/graph/utils/serializers.py similarity index 62% rename from pychunkedgraph/backend/utils/serializers.py rename to pychunkedgraph/graph/utils/serializers.py index 401204642..09c0f63b0 100644 --- a/pychunkedgraph/backend/utils/serializers.py +++ b/pychunkedgraph/graph/utils/serializers.py @@ -1,18 +1,27 @@ from typing import Any, Iterable import json +import pickle + import numpy as np +import zstandard as zstd -class _Serializer(): - def __init__(self, serializer, deserializer, basetype=Any): +class _Serializer: + def __init__(self, serializer, deserializer, basetype=Any, compression_level=None): self._serializer = serializer self._deserializer = deserializer self._basetype = basetype + self._compression_level = compression_level def serialize(self, obj): - return self._serializer(obj) + content = self._serializer(obj) + if self._compression_level: + return zstd.ZstdCompressor(level=self._compression_level).compress(content) + return content def deserialize(self, obj): + if self._compression_level: + obj = zstd.ZstdDecompressor().decompressobj().decompress(obj) return self._deserializer(obj) @property @@ -30,11 +39,14 @@ def _deserialize(val, dtype, shape=None, order=None): return data.reshape(data.shape, order=order) return data - def __init__(self, dtype, shape=None, order=None): + def __init__(self, dtype, shape=None, order=None, compression_level=None): super().__init__( serializer=lambda x: x.newbyteorder(dtype.byteorder).tobytes(), - deserializer=lambda x: NumPyArray._deserialize(x, dtype, shape=shape, order=order), - basetype=dtype.type + deserializer=lambda x: NumPyArray._deserialize( + x, dtype, shape=shape, order=order + ), + basetype=dtype.type, + compression_level=compression_level, ) @@ -43,7 +55,7 @@ def __init__(self, dtype): super().__init__( serializer=lambda x: x.newbyteorder(dtype.byteorder).tobytes(), deserializer=lambda x: np.frombuffer(x, dtype=dtype)[0], - basetype=dtype.type + basetype=dtype.type, ) @@ -52,7 +64,7 @@ def __init__(self, encoding="utf-8"): super().__init__( serializer=lambda x: x.encode(encoding), deserializer=lambda x: x.decode(), - basetype=str + basetype=str, ) @@ -61,7 +73,16 @@ def __init__(self): super().__init__( serializer=lambda x: json.dumps(x).encode("utf-8"), deserializer=lambda x: json.loads(x.decode()), - basetype=str + basetype=str, + ) + + +class Pickle(_Serializer): + def __init__(self): + super().__init__( + serializer=lambda x: pickle.dumps(x), + deserializer=lambda x: pickle.loads(x), + basetype=str, ) @@ -70,7 +91,7 @@ def __init__(self): super().__init__( serializer=serialize_uint64, deserializer=deserialize_uint64, - basetype=np.uint64 + basetype=np.uint64, ) @@ -83,12 +104,16 @@ def pad_node_id(node_id: np.uint64) -> str: return "%.20d" % node_id -def serialize_uint64(node_id: np.uint64) -> bytes: +def serialize_uint64(node_id: np.uint64, counter=False, fake_edges=False) -> bytes: """ Serializes an id to be ingested by a bigtable table row :param node_id: int :return: str """ + if counter: + return serialize_key("i%s" % pad_node_id(node_id)) # type: ignore + if fake_edges: + return serialize_key("f%s" % pad_node_id(node_id)) # type: ignore return serialize_key(pad_node_id(node_id)) # type: ignore @@ -98,17 +123,18 @@ def serialize_uint64s_to_regex(node_ids: Iterable[np.uint64]) -> bytes: :param node_id: int :return: str """ - node_id_str = "".join(["%s|" % pad_node_id(node_id) - for node_id in node_ids])[:-1] + node_id_str = 
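The compression hook added to `_Serializer` above can be exercised independently of the pychunkedgraph classes. A rough round-trip sketch, assuming the `zstandard` package is installed; the payload and compression level are arbitrary choices for the demo:

```python
import pickle
import zstandard as zstd

def serialize(obj, level=17):
    # pickle first, then compress with zstd at the given level
    return zstd.ZstdCompressor(level=level).compress(pickle.dumps(obj))

def deserialize(blob):
    # decompressobj() handles frames that do not store a content size up front
    return pickle.loads(zstd.ZstdDecompressor().decompressobj().decompress(blob))

payload = {"edges": [(1, 2), (2, 3)], "layer": 2}
assert deserialize(serialize(payload)) == payload
```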
"".join(["%s|" % pad_node_id(node_id) for node_id in node_ids])[:-1] return serialize_key(node_id_str) # type: ignore -def deserialize_uint64(node_id: bytes) -> np.uint64: +def deserialize_uint64(node_id: bytes, fake_edges=False) -> np.uint64: """ De-serializes a node id from a BigTable row :param node_id: bytes :return: np.uint64 """ + if fake_edges: + return np.uint64(node_id[1:].decode()) # type: ignore return np.uint64(node_id.decode()) # type: ignore diff --git a/pychunkedgraph/ingest/__init__.py b/pychunkedgraph/ingest/__init__.py index e69de29bb..b3d832d5e 100644 --- a/pychunkedgraph/ingest/__init__.py +++ b/pychunkedgraph/ingest/__init__.py @@ -0,0 +1,32 @@ +from collections import namedtuple + + +_cluster_ingest_config_fields = ( + "ATOMIC_Q_NAME", + "ATOMIC_Q_LIMIT", + "ATOMIC_Q_INTERVAL", +) +_cluster_ingest_defaults = ( + "l2", + 100000, + 120, +) +ClusterIngestConfig = namedtuple( + "ClusterIngestConfig", + _cluster_ingest_config_fields, + defaults=_cluster_ingest_defaults, +) + + +_ingestconfig_fields = ( + "CLUSTER", # cluster config + "AGGLOMERATION", + "WATERSHED", + "USE_RAW_EDGES", + "USE_RAW_COMPONENTS", + "TEST_RUN", +) +_ingestconfig_defaults = (None, None, None, False, False, False) +IngestConfig = namedtuple( + "IngestConfig", _ingestconfig_fields, defaults=_ingestconfig_defaults +) diff --git a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py new file mode 100644 index 000000000..7668e8f24 --- /dev/null +++ b/pychunkedgraph/ingest/cli.py @@ -0,0 +1,194 @@ +""" +cli for running ingest +""" + +from os import environ +from time import sleep + +import click +import yaml +from flask.cli import AppGroup +from rq import Queue + +from .manager import IngestionManager +from .utils import bootstrap +from .cluster import randomize_grid_points +from ..graph.chunkedgraph import ChunkedGraph +from ..utils.redis import get_redis_connection +from ..utils.redis import keys as r_keys +from ..utils.general import chunked + +ingest_cli = AppGroup("ingest") + + +def init_ingest_cmds(app): + app.cli.add_command(ingest_cli) + + +@ingest_cli.command("flush_redis") +def flush_redis(): + """FLush redis db.""" + redis = get_redis_connection() + redis.flushdb() + + +@ingest_cli.command("graph") +@click.argument("graph_id", type=str) +@click.argument("dataset", type=click.Path(exists=True)) +@click.option("--raw", is_flag=True) +@click.option("--test", is_flag=True) +@click.option("--retry", is_flag=True) +def ingest_graph( + graph_id: str, dataset: click.Path, raw: bool, test: bool, retry: bool +): + """ + Main ingest command. + Takes ingest config from a yaml file and queues atomic tasks. + """ + from .cluster import enqueue_atomic_tasks + + with open(dataset, "r") as stream: + config = yaml.safe_load(stream) + + meta, ingest_config, client_info = bootstrap( + graph_id, + config=config, + raw=raw, + test_run=test, + ) + cg = ChunkedGraph(meta=meta, client_info=client_info) + if not retry: + cg.create() + enqueue_atomic_tasks(IngestionManager(ingest_config, meta)) + + +@ingest_cli.command("imanager") +@click.argument("graph_id", type=str) +@click.argument("dataset", type=click.Path(exists=True)) +@click.option("--raw", is_flag=True) +def pickle_imanager(graph_id: str, dataset: click.Path, raw: bool): + """ + Load ingest config into redis server. + Must only be used if ingest config is lost/corrupted during ingest. 
+ """ + with open(dataset, "r") as stream: + try: + config = yaml.safe_load(stream) + except yaml.YAMLError as exc: + print(exc) + + meta, ingest_config, _ = bootstrap(graph_id, config=config, raw=raw) + imanager = IngestionManager(ingest_config, meta) + imanager.redis + + +@ingest_cli.command("layer") +@click.argument("parent_layer", type=int) +def queue_layer(parent_layer): + """ + Queue all chunk tasks at a given layer. + Must be used when all the chunks at `parent_layer - 1` have completed. + """ + from itertools import product + import numpy as np + from .cluster import create_parent_chunk + from .utils import chunk_id_str + + assert parent_layer > 2, "This command is for layers 3 and above." + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + + if parent_layer == imanager.cg_meta.layer_count: + chunk_coords = [(0, 0, 0)] + else: + bounds = imanager.cg_meta.layer_chunk_bounds[parent_layer] + chunk_coords = randomize_grid_points(*bounds) + + def get_chunks_not_done(coords: list) -> list: + """check for set membership in redis in batches""" + coords_strs = ["_".join(map(str, coord)) for coord in coords] + try: + completed = imanager.redis.smismember(f"{parent_layer}c", coords_strs) + except Exception: + return coords + return [coord for coord, c in zip(coords, completed) if not c] + + batch_size = int(environ.get("JOB_BATCH_SIZE", 10000)) + batches = chunked(chunk_coords, batch_size) + q = imanager.get_task_queue(f"l{parent_layer}") + + for batch in batches: + _coords = get_chunks_not_done(batch) + # buffer for optimal use of redis memory + if len(q) > int(environ.get("QUEUE_SIZE", 100000)): + interval = int(environ.get("QUEUE_INTERVAL", 300)) + sleep(interval) + + job_datas = [] + for chunk_coord in _coords: + job_datas.append( + Queue.prepare_data( + create_parent_chunk, + args=(parent_layer, chunk_coord), + result_ttl=0, + job_id=chunk_id_str(parent_layer, chunk_coord), + timeout=f"{int(parent_layer * parent_layer)}m", + ) + ) + q.enqueue_many(job_datas) + + +@ingest_cli.command("status") +def ingest_status(): + """Print ingest status to console by layer.""" + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + layers = range(2, imanager.cg_meta.layer_count + 1) + for layer, layer_count in zip(layers, imanager.cg_meta.layer_chunk_counts): + completed = redis.scard(f"{layer}c") + print(f"{layer}\t: {completed} / {layer_count}") + + +@ingest_cli.command("chunk") +@click.argument("queue", type=str) +@click.argument("chunk_info", nargs=4, type=int) +def ingest_chunk(queue: str, chunk_info): + """Manually queue chunk when a job is stuck for whatever reason.""" + from .cluster import _create_atomic_chunk + from .cluster import create_parent_chunk + from .utils import chunk_id_str + + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + layer = chunk_info[0] + coords = chunk_info[1:] + queue = imanager.get_task_queue(queue) + if layer == 2: + func = _create_atomic_chunk + args = (coords,) + else: + func = create_parent_chunk + args = (layer, coords) + queue.enqueue( + func, + job_id=chunk_id_str(layer, coords), + job_timeout=f"{int(layer * layer)}m", + result_ttl=0, + args=args, + ) + + +@ingest_cli.command("chunk_local") +@click.argument("graph_id", type=str) +@click.argument("chunk_info", nargs=4, type=int) +@click.option("--n_threads", type=int, default=1) +def ingest_chunk_local(graph_id: str, chunk_info, 
n_threads: int): + """Manually ingest a chunk on a local machine.""" + from .create.abstract_layers import add_layer + from .cluster import _create_atomic_chunk + + if chunk_info[0] == 2: + _create_atomic_chunk(chunk_info[1:]) + else: + cg = ChunkedGraph(graph_id=graph_id) + add_layer(cg, chunk_info[0], chunk_info[1:], n_threads=n_threads) diff --git a/pychunkedgraph/ingest/cluster.py b/pychunkedgraph/ingest/cluster.py new file mode 100644 index 000000000..cf9417024 --- /dev/null +++ b/pychunkedgraph/ingest/cluster.py @@ -0,0 +1,195 @@ +""" +Ingest / create chunkedgraph with workers. +""" + +from typing import Sequence, Tuple + +import numpy as np + +from .utils import chunk_id_str +from .manager import IngestionManager +from .common import get_atomic_chunk_data +from .ran_agglomeration import get_active_edges +from .create.atomic_layer import add_atomic_edges +from .create.abstract_layers import add_layer +from ..graph.meta import ChunkedGraphMeta +from ..graph.chunks.hierarchy import get_children_chunk_coords +from ..utils.redis import keys as r_keys +from ..utils.redis import get_redis_connection + + +def _post_task_completion(imanager: IngestionManager, layer: int, coords: np.ndarray): + from os import environ + + chunk_str = "_".join(map(str, coords)) + # mark chunk as completed - "c" + imanager.redis.sadd(f"{layer}c", chunk_str) + + if environ.get("DO_NOT_AUTOQUEUE_PARENT_CHUNKS", None) is not None: + return + + parent_layer = layer + 1 + if parent_layer > imanager.cg_meta.layer_count: + return + + parent_coords = np.array(coords, int) // imanager.cg_meta.graph_config.FANOUT + parent_id_str = chunk_id_str(parent_layer, parent_coords) + imanager.redis.sadd(parent_id_str, chunk_str) + + parent_chunk_str = "_".join(map(str, parent_coords)) + if not imanager.redis.hget(parent_layer, parent_chunk_str): + # cache children chunk count + # checked by tracker worker to enqueue parent chunk + children_count = len( + get_children_chunk_coords(imanager.cg_meta, parent_layer, parent_coords) + ) + imanager.redis.hset(parent_layer, parent_chunk_str, children_count) + + tracker_queue = imanager.get_task_queue(f"t{layer}") + tracker_queue.enqueue( + enqueue_parent_task, + job_id=f"t{layer}_{chunk_str}", + job_timeout=f"30s", + result_ttl=0, + args=( + parent_layer, + parent_coords, + ), + ) + + +def enqueue_parent_task( + parent_layer: int, + parent_coords: Sequence[int], +): + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + parent_id_str = chunk_id_str(parent_layer, parent_coords) + parent_chunk_str = "_".join(map(str, parent_coords)) + + children_done = redis.scard(parent_id_str) + # if zero then this key was deleted and parent already queued. 
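The completion bookkeeping used here (a Redis set of finished chunks per layer plus a hash of expected child counts per parent) can be mimicked with plain Python containers. A toy sketch with hypothetical key names and no Redis involved:

```python
FANOUT = 2
completed = {2: set()}        # layer -> finished chunk strings (the "{layer}c" sets)
children_done = {}            # parent key -> set of finished children
children_expected = {}        # parent key -> expected child count (the per-layer hash)

def mark_done(layer, coords, n_children_of_parent):
    chunk_str = "_".join(map(str, coords))
    completed.setdefault(layer, set()).add(chunk_str)

    parent_coords = tuple(c // FANOUT for c in coords)
    parent_key = f"{layer + 1}_" + "_".join(map(str, parent_coords))
    children_done.setdefault(parent_key, set()).add(chunk_str)
    children_expected.setdefault(parent_key, n_children_of_parent)

    # the parent chunk becomes ready once all of its children have reported in
    return len(children_done[parent_key]) == children_expected[parent_key]

print(mark_done(2, (0, 0, 0), 8))   # False - only 1 of 8 children finished
print(mark_done(2, (1, 1, 1), 8))   # still False - 6 more children to go
```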
+ if children_done == 0: + print("parent already queued.") + return + + # if the previous layer is complete + # no need to check children progress for each parent chunk + child_layer = parent_layer - 1 + child_layer_done = redis.scard(f"{child_layer}c") + child_layer_count = imanager.cg_meta.layer_chunk_counts[child_layer - 2] + child_layer_finished = child_layer_done == child_layer_count + + if not child_layer_finished: + children_count = int(redis.hget(parent_layer, parent_chunk_str).decode("utf-8")) + if children_done != children_count: + print("children not done.") + return + + queue = imanager.get_task_queue(f"l{parent_layer}") + queue.enqueue( + create_parent_chunk, + job_id=parent_id_str, + job_timeout=f"{int(parent_layer * parent_layer)}m", + result_ttl=0, + args=( + parent_layer, + parent_coords, + ), + ) + redis.hdel(parent_layer, parent_chunk_str) + redis.delete(parent_id_str) + + +def create_parent_chunk( + parent_layer: int, + parent_coords: Sequence[int], +) -> None: + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + add_layer( + imanager.cg, + parent_layer, + parent_coords, + get_children_chunk_coords( + imanager.cg_meta, + parent_layer, + parent_coords, + ), + ) + _post_task_completion(imanager, parent_layer, parent_coords) + + +def randomize_grid_points(X: int, Y: int, Z: int) -> Tuple[int, int, int]: + indices = np.arange(X * Y * Z) + np.random.shuffle(indices) + for index in indices: + yield np.unravel_index(index, (X, Y, Z)) + + +def enqueue_atomic_tasks(imanager: IngestionManager): + from os import environ + from time import sleep + from rq import Queue as RQueue + + chunk_coords = _get_test_chunks(imanager.cg.meta) + chunk_count = len(chunk_coords) + if not imanager.config.TEST_RUN: + atomic_chunk_bounds = imanager.cg_meta.layer_chunk_bounds[2] + chunk_coords = randomize_grid_points(*atomic_chunk_bounds) + chunk_count = imanager.cg_meta.layer_chunk_counts[0] + + print(f"total chunk count: {chunk_count}, queuing...") + batch_size = int(environ.get("L2JOB_BATCH_SIZE", 1000)) + + job_datas = [] + for chunk_coord in chunk_coords: + q = imanager.get_task_queue(imanager.config.CLUSTER.ATOMIC_Q_NAME) + # buffer for optimal use of redis memory + if len(q) > imanager.config.CLUSTER.ATOMIC_Q_LIMIT: + print(f"Sleeping {imanager.config.CLUSTER.ATOMIC_Q_INTERVAL}s...") + sleep(imanager.config.CLUSTER.ATOMIC_Q_INTERVAL) + + x, y, z = chunk_coord + chunk_str = f"{x}_{y}_{z}" + if imanager.redis.sismember("2c", chunk_str): + # already done, skip + continue + job_datas.append( + RQueue.prepare_data( + _create_atomic_chunk, + args=(chunk_coord,), + timeout=environ.get("L2JOB_TIMEOUT", "3m"), + result_ttl=0, + job_id=chunk_id_str(2, chunk_coord), + ) + ) + if len(job_datas) % batch_size == 0: + q.enqueue_many(job_datas) + job_datas = [] + q.enqueue_many(job_datas) + + +def _create_atomic_chunk(coords: Sequence[int]): + """Creates single atomic chunk""" + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + coords = np.array(list(coords), dtype=int) + chunk_edges_all, mapping = get_atomic_chunk_data(imanager, coords) + chunk_edges_active, isolated_ids = get_active_edges(chunk_edges_all, mapping) + add_atomic_edges(imanager.cg, coords, chunk_edges_active, isolated=isolated_ids) + if imanager.config.TEST_RUN: + # print for debugging + for k, v in chunk_edges_all.items(): + print(k, len(v)) + for k, v in chunk_edges_active.items(): + print(f"active_{k}", len(v)) + 
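For reference, the `randomize_grid_points` generator used above simply shuffles flat indices and unravels each one back into a grid coordinate; a standalone copy to show the shapes involved:

```python
import numpy as np

def randomize_grid_points(X, Y, Z):
    # shuffle flat indices, then unravel each into an (x, y, z) grid coordinate
    indices = np.arange(X * Y * Z)
    np.random.shuffle(indices)
    for index in indices:
        yield np.unravel_index(index, (X, Y, Z))

np.random.seed(0)                        # only to make this demo repeatable
coords = list(randomize_grid_points(2, 3, 4))
print(len(coords))                       # 24 - every grid point appears exactly once
print(len(set(coords)))                  # 24 - no duplicates
```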
_post_task_completion(imanager, 2, coords) + + +def _get_test_chunks(meta: ChunkedGraphMeta): + """Chunks at center of the dataset most likely not to be empty""" + parent_coords = np.array(meta.layer_chunk_bounds[3]) // 2 + return get_children_chunk_coords(meta, 3, parent_coords) + # f = lambda r1, r2, r3: np.array(np.meshgrid(r1, r2, r3), dtype=int).T.reshape(-1, 3) + # return f((x, x + 1), (y, y + 1), (z, z + 1)) diff --git a/pychunkedgraph/ingest/common.py b/pychunkedgraph/ingest/common.py new file mode 100644 index 000000000..dccf58602 --- /dev/null +++ b/pychunkedgraph/ingest/common.py @@ -0,0 +1,61 @@ +from typing import Dict +from typing import Tuple +from typing import Sequence + +from .manager import IngestionManager +from .ran_agglomeration import read_raw_edge_data +from .ran_agglomeration import read_raw_agglomeration_data +from ..graph import ChunkedGraph +from ..io.edges import get_chunk_edges +from ..io.components import get_chunk_components + + +def get_atomic_chunk_data( + imanager: IngestionManager, coord: Sequence[int] +) -> Tuple[Dict, Dict]: + """ + Helper to read either raw data or processed data + If reading from raw data, save it as processed data + """ + chunk_edges = ( + read_raw_edge_data(imanager, coord) + if imanager.config.USE_RAW_EDGES + else get_chunk_edges(imanager.cg_meta.data_source.EDGES, [coord]) + ) + + _check_edges_direction(chunk_edges, imanager.cg, coord) + + mapping = ( + read_raw_agglomeration_data(imanager, coord) + if imanager.config.USE_RAW_COMPONENTS + else get_chunk_components(imanager.cg_meta.data_source.COMPONENTS, coord) + ) + return chunk_edges, mapping + + +def _check_edges_direction( + chunk_edges: dict, cg: ChunkedGraph, coord: Sequence[int] +) -> None: + """ + For between and cross chunk edges: + Checks and flips edges such that nodes1 are always within a chunk and nodes2 outside the chunk. + Where nodes1 = edges[:,0] and nodes2 = edges[:,1]. 
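A toy numpy sketch of this flip, with a plain dict standing in for `cg.get_chunk_ids_from_node_ids` and made-up ids:

```python
import numpy as np

chunk_id = 10
node_ids1 = np.array([1, 2, 3])
node_ids2 = np.array([4, 5, 6])
node_chunk = {1: 10, 2: 99, 3: 10, 4: 88, 5: 10, 6: 77}   # node -> chunk id

c2 = np.array([node_chunk[n] for n in node_ids2])
mask = c2 == chunk_id                        # rows whose "outside" node is actually inside
node_ids1[mask], node_ids2[mask] = node_ids2[mask].copy(), node_ids1[mask].copy()

c1 = np.array([node_chunk[n] for n in node_ids1])
assert np.all(c1 == chunk_id)                # every first endpoint now lies in the chunk
print(node_ids1, node_ids2)                  # [1 5 3] [4 2 6]
```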
+ """ + import numpy as np + from ..graph.edges import Edges + from ..graph.edges import EDGE_TYPES + + x, y, z = coord + chunk_id = cg.get_chunk_id(layer=1, x=x, y=y, z=z) + for edge_type in [EDGE_TYPES.between_chunk, EDGE_TYPES.cross_chunk]: + edges = chunk_edges[edge_type] + e1 = edges.node_ids1 + e2 = edges.node_ids2 + + e2_chunk_ids = cg.get_chunk_ids_from_node_ids(e2) + mask = e2_chunk_ids == chunk_id + e1[mask], e2[mask] = e2[mask], e1[mask] + + e1_chunk_ids = cg.get_chunk_ids_from_node_ids(e1) + mask = e1_chunk_ids == chunk_id + assert np.all(mask), "all IDs must belong to same chunk" diff --git a/pychunkedgraph/ingest/create/__init__.py b/pychunkedgraph/ingest/create/__init__.py new file mode 100644 index 000000000..009da30ed --- /dev/null +++ b/pychunkedgraph/ingest/create/__init__.py @@ -0,0 +1,3 @@ +""" +modules for chunkedgraph initialization/creation +""" diff --git a/pychunkedgraph/ingest/create/abstract_layers.py b/pychunkedgraph/ingest/create/abstract_layers.py new file mode 100644 index 000000000..529a6846f --- /dev/null +++ b/pychunkedgraph/ingest/create/abstract_layers.py @@ -0,0 +1,247 @@ +""" +Functions for creating parents in level 3 and above +""" + +import time +import math +import datetime +import multiprocessing as mp +from collections import defaultdict +from typing import Optional +from typing import Sequence +from typing import List + +import numpy as np +from multiwrapper import multiprocessing_utils as mu + +from ...graph import types +from ...graph import attributes +from ...utils.general import chunked +from ...graph.utils import flatgraph +from ...graph.utils import basetypes +from ...graph.utils import serializers +from ...graph.chunkedgraph import ChunkedGraph +from ...graph.utils.generic import get_valid_timestamp +from ...graph.utils.generic import filter_failed_node_ids +from ...graph.chunks.hierarchy import get_children_chunk_coords +from ...graph.connectivity.cross_edges import get_children_chunk_cross_edges +from ...graph.connectivity.cross_edges import get_chunk_nodes_cross_edge_layer + + +def add_layer( + cg: ChunkedGraph, + layer_id: int, + parent_coords: Sequence[int], + children_coords: Sequence[Sequence[int]] = np.array([]), + *, + time_stamp: Optional[datetime.datetime] = None, + n_threads: int = 4, +) -> None: + if not children_coords.size: + children_coords = get_children_chunk_coords(cg.meta, layer_id, parent_coords) + children_ids = _read_children_chunks(cg, layer_id, children_coords, n_threads > 1) + edge_ids = get_children_chunk_cross_edges( + cg, layer_id, parent_coords, use_threads=n_threads > 1 + ) + + print("children_coords", children_coords.size, layer_id, parent_coords) + print( + "n e", len(children_ids), len(edge_ids), layer_id, parent_coords, + ) + + node_layers = cg.get_chunk_layers(children_ids) + edge_layers = cg.get_chunk_layers(np.unique(edge_ids)) + assert np.all(node_layers < layer_id), "invalid node layers" + assert np.all(edge_layers < layer_id), "invalid edge layers" + # Extract connected components + # isolated_node_mask = ~np.in1d(children_ids, np.unique(edge_ids)) + # add_node_ids = children_ids[isolated_node_mask].squeeze() + add_edge_ids = np.vstack([children_ids, children_ids]).T + + edge_ids = list(edge_ids) + edge_ids.extend(add_edge_ids) + graph, _, _, graph_ids = flatgraph.build_gt_graph(edge_ids, make_directed=True) + ccs = flatgraph.connected_components(graph) + print("ccs", len(ccs)) + _write_connected_components( + cg, + layer_id, + parent_coords, + ccs, + graph_ids, + get_valid_timestamp(time_stamp), 
+ n_threads > 1, + ) + return f"{layer_id}_{'_'.join(map(str, parent_coords))}" + + +def _read_children_chunks( + cg: ChunkedGraph, layer_id, children_coords, use_threads=True +): + if not use_threads: + children_ids = [types.empty_1d] + for child_coord in children_coords: + children_ids.append(_read_chunk([], cg, layer_id - 1, child_coord)) + return np.concatenate(children_ids) + + print("_read_children_chunks") + with mp.Manager() as manager: + children_ids_shared = manager.list() + multi_args = [] + for child_coord in children_coords: + multi_args.append( + ( + children_ids_shared, + cg.get_serialized_info(), + layer_id - 1, + child_coord, + ) + ) + mu.multiprocess_func( + _read_chunk_helper, + multi_args, + n_threads=min(len(multi_args), mp.cpu_count()), + ) + print("_read_children_chunks done") + return np.concatenate(children_ids_shared) + + +def _read_chunk_helper(args): + children_ids_shared, cg_info, layer_id, chunk_coord = args + cg = ChunkedGraph(**cg_info) + _read_chunk(children_ids_shared, cg, layer_id, chunk_coord) + + +def _read_chunk(children_ids_shared, cg: ChunkedGraph, layer_id: int, chunk_coord): + print(f"_read_chunk {layer_id}, {chunk_coord}") + x, y, z = chunk_coord + range_read = cg.range_read_chunk( + cg.get_chunk_id(layer=layer_id, x=x, y=y, z=z), + properties=attributes.Hierarchy.Child, + ) + row_ids = [] + max_children_ids = [] + for row_id, row_data in range_read.items(): + row_ids.append(row_id) + max_children_ids.append(np.max(row_data[0].value)) + row_ids = np.array(row_ids, dtype=basetypes.NODE_ID) + segment_ids = np.array([cg.get_segment_id(r_id) for r_id in row_ids]) + + row_ids = filter_failed_node_ids(row_ids, segment_ids, max_children_ids) + children_ids_shared.append(row_ids) + print(f"_read_chunk {layer_id}, {chunk_coord} done {len(row_ids)}") + return row_ids + + +def _write_connected_components( + cg: ChunkedGraph, + layer_id: int, + parent_coords, + ccs, + graph_ids, + time_stamp, + use_threads=True, +) -> None: + if not ccs: + return + + node_layer_d_shared = {} + if layer_id < cg.meta.layer_count: + print("getting node_layer_d_shared") + node_layer_d_shared = get_chunk_nodes_cross_edge_layer( + cg, layer_id, parent_coords, use_threads=use_threads + ) + + print("node_layer_d_shared", len(node_layer_d_shared)) + + ccs_with_node_ids = [] + for cc in ccs: + ccs_with_node_ids.append(graph_ids[cc]) + + if not use_threads: + _write( + cg, + layer_id, + parent_coords, + ccs_with_node_ids, + node_layer_d_shared, + time_stamp, + use_threads=use_threads, + ) + return + + task_size = int(math.ceil(len(ccs_with_node_ids) / mp.cpu_count() / 10)) + chunked_ccs = chunked(ccs_with_node_ids, task_size) + cg_info = cg.get_serialized_info() + multi_args = [] + for ccs in chunked_ccs: + multi_args.append( + (cg_info, layer_id, parent_coords, ccs, node_layer_d_shared, time_stamp) + ) + mu.multiprocess_func( + _write_components_helper, + multi_args, + n_threads=min(len(multi_args), mp.cpu_count()), + ) + + +def _write_components_helper(args): + print("running _write_components_helper") + cg_info, layer_id, parent_coords, ccs, node_layer_d_shared, time_stamp = args + cg = ChunkedGraph(**cg_info) + _write(cg, layer_id, parent_coords, ccs, node_layer_d_shared, time_stamp) + + +def _write( + cg, layer_id, parent_coords, ccs, node_layer_d_shared, time_stamp, use_threads=True +): + parent_layer_ids = range(layer_id, cg.meta.layer_count + 1) + cc_connections = {l: [] for l in parent_layer_ids} + for node_ids in ccs: + layer = layer_id + if len(node_ids) == 1: + layer = 
node_layer_d_shared.get(node_ids[0], cg.meta.layer_count) + cc_connections[layer].append(node_ids) + + rows = [] + x, y, z = parent_coords + parent_chunk_id = cg.get_chunk_id(layer=layer_id, x=x, y=y, z=z) + parent_chunk_id_dict = cg.get_parent_chunk_id_dict(parent_chunk_id) + + # Iterate through layers + for parent_layer_id in parent_layer_ids: + if len(cc_connections[parent_layer_id]) == 0: + continue + + parent_chunk_id = parent_chunk_id_dict[parent_layer_id] + reserved_parent_ids = cg.id_client.create_node_ids( + parent_chunk_id, + size=len(cc_connections[parent_layer_id]), + root_chunk=parent_layer_id == cg.meta.layer_count and use_threads, + ) + + for i_cc, node_ids in enumerate(cc_connections[parent_layer_id]): + parent_id = reserved_parent_ids[i_cc] + for node_id in node_ids: + rows.append( + cg.client.mutate_row( + serializers.serialize_uint64(node_id), + {attributes.Hierarchy.Parent: parent_id}, + time_stamp=time_stamp, + ) + ) + + rows.append( + cg.client.mutate_row( + serializers.serialize_uint64(parent_id), + {attributes.Hierarchy.Child: node_ids}, + time_stamp=time_stamp, + ) + ) + + if len(rows) > 100000: + cg.client.write(rows) + print("wrote rows", len(rows), layer_id, parent_coords) + rows = [] + cg.client.write(rows) + print("wrote rows", len(rows), layer_id, parent_coords) diff --git a/pychunkedgraph/ingest/create/atomic_layer.py b/pychunkedgraph/ingest/create/atomic_layer.py new file mode 100644 index 000000000..4fa1f1688 --- /dev/null +++ b/pychunkedgraph/ingest/create/atomic_layer.py @@ -0,0 +1,147 @@ +""" +Functions for creating atomic nodes and their level 2 abstract parents +""" + +import datetime +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence + +import pytz +import numpy as np + +from ...graph import attributes +from ...graph.chunkedgraph import ChunkedGraph +from ...graph.utils import basetypes +from ...graph.utils import serializers +from ...graph.edges import Edges +from ...graph.edges import EDGE_TYPES +from ...graph.utils.generic import compute_indices_pandas +from ...graph.utils.generic import get_valid_timestamp +from ...graph.utils.flatgraph import build_gt_graph +from ...graph.utils.flatgraph import connected_components + + +def add_atomic_edges( + cg: ChunkedGraph, + chunk_coord: np.ndarray, + chunk_edges_d: Dict[str, Edges], + isolated: Sequence[int], + time_stamp: Optional[datetime.datetime] = None, +): + chunk_node_ids, chunk_edge_ids = _get_chunk_nodes_and_edges(chunk_edges_d, isolated) + if not chunk_node_ids.size: + return + + chunk_ids = cg.get_chunk_ids_from_node_ids(chunk_node_ids) + assert len(np.unique(chunk_ids)) == 1 + + graph, _, _, unique_ids = build_gt_graph(chunk_edge_ids, make_directed=True) + ccs = connected_components(graph) + + parent_chunk_id = cg.get_chunk_id( + layer=2, x=chunk_coord[0], y=chunk_coord[1], z=chunk_coord[2] + ) + parent_ids = cg.id_client.create_node_ids(parent_chunk_id, size=len(ccs)) + + sparse_indices, remapping = _get_remapping(chunk_edges_d) + time_stamp = get_valid_timestamp(time_stamp) + nodes = [] + for i_cc, component in enumerate(ccs): + _nodes = _process_component( + cg, + chunk_edges_d, + parent_ids[i_cc], + unique_ids[component], + sparse_indices, + remapping, + time_stamp, + ) + nodes.extend(_nodes) + if len(nodes) > 100000: + cg.client.write(nodes) + nodes = [] + cg.client.write(nodes) + + +def _get_chunk_nodes_and_edges(chunk_edges_d: dict, isolated_ids: Sequence[int]): + """ + in-chunk edges and nodes_ids + """ + isolated_nodes_self_edges 
= np.vstack([isolated_ids, isolated_ids]).T + node_ids = [isolated_ids] + edge_ids = [isolated_nodes_self_edges] + for edge_type in EDGE_TYPES: + edges = chunk_edges_d[edge_type] + node_ids.append(edges.node_ids1) + if edge_type == EDGE_TYPES.in_chunk: + node_ids.append(edges.node_ids2) + edge_ids.append(edges.get_pairs()) + + chunk_node_ids = np.unique(np.concatenate(node_ids)) + edge_ids.append(np.vstack([chunk_node_ids, chunk_node_ids]).T) + return (chunk_node_ids, np.concatenate(edge_ids)) + + +def _get_remapping(chunk_edges_d: dict): + """ + TODO add logic explanation + """ + sparse_indices = {} + remapping = {} + for edge_type in [EDGE_TYPES.between_chunk, EDGE_TYPES.cross_chunk]: + edges = chunk_edges_d[edge_type].get_pairs() + u_ids, inv_ids = np.unique(edges, return_inverse=True) + mapped_ids = np.arange(len(u_ids), dtype=np.int32) + remapped_arr = mapped_ids[inv_ids].reshape(edges.shape) + sparse_indices[edge_type] = compute_indices_pandas(remapped_arr) + remapping[edge_type] = dict(zip(u_ids, mapped_ids)) + return sparse_indices, remapping + + +def _process_component( + cg, chunk_edges_d, parent_id, node_ids, sparse_indices, remapping, time_stamp, +): + nodes = [] + chunk_out_edges = [] # out = between + cross + for node_id in node_ids: + _edges = _get_outgoing_edges(node_id, chunk_edges_d, sparse_indices, remapping) + chunk_out_edges.append(_edges) + val_dict = {attributes.Hierarchy.Parent: parent_id} + r_key = serializers.serialize_uint64(node_id) + nodes.append(cg.client.mutate_row(r_key, val_dict, time_stamp=time_stamp)) + + chunk_out_edges = np.concatenate(chunk_out_edges) + cce_layers = cg.get_cross_chunk_edges_layer(chunk_out_edges) + u_cce_layers = np.unique(cce_layers) + + val_dict = {attributes.Hierarchy.Child: node_ids} + for cc_layer in u_cce_layers: + layer_out_edges = chunk_out_edges[cce_layers == cc_layer] + if layer_out_edges.size: + col = attributes.Connectivity.CrossChunkEdge[cc_layer] + val_dict[col] = layer_out_edges + + r_key = serializers.serialize_uint64(parent_id) + nodes.append(cg.client.mutate_row(r_key, val_dict, time_stamp=time_stamp)) + return nodes + + +def _get_outgoing_edges(node_id, chunk_edges_d, sparse_indices, remapping): + """ + edges of node_id pointing outside the chunk (between and cross) + """ + chunk_out_edges = np.array([], dtype=basetypes.NODE_ID).reshape(0, 2) + for edge_type in remapping: + if node_id in remapping[edge_type]: + edges_obj = chunk_edges_d[edge_type] + edges = edges_obj.get_pairs() + + row_ids, column_ids = sparse_indices[edge_type][ + remapping[edge_type][node_id] + ] + row_ids = row_ids[column_ids == 0] + # edges that this node is part of + chunk_out_edges = np.concatenate([chunk_out_edges, edges[row_ids]]) + return chunk_out_edges diff --git a/pychunkedgraph/ingest/ingestion_utils.py b/pychunkedgraph/ingest/ingestion_utils.py deleted file mode 100644 index 3a3897940..000000000 --- a/pychunkedgraph/ingest/ingestion_utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import numpy as np -from pychunkedgraph.backend import chunkedgraph, chunkedgraph_utils - -import cloudvolume - - -def calc_n_layers(ws_cv, chunk_size, fan_out): - bbox = np.array(ws_cv.bounds.to_list()).reshape(2, 3) - n_chunks = ((bbox[1] - bbox[0]) / chunk_size).astype(np.int) - n_layers = int( np.ceil(chunkedgraph_utils.log_n(np.max(n_chunks), fan_out))) + 2 - return n_layers - - -def initialize_chunkedgraph(cg_table_id, ws_cv_path, chunk_size, size, - cg_mesh_dir, use_skip_connections=True, - s_bits_atomic_layer=None, - n_bits_root_counter=8, fan_out=2, 
instance_id=None, - project_id=None): - """ Initalizes a chunkedgraph on BigTable - - :param cg_table_id: str - name of chunkedgraph - :param ws_cv_path: str - path to watershed segmentation on Google Cloud - :param chunk_size: np.ndarray - array of three ints - :param size: np.ndarray - array of three ints - :param cg_mesh_dir: str - mesh folder name - :param s_bits_atomic_layer: int or None - number of bits for each x, y and z on the lower layer - :param n_bits_root_counter: int or None - number of bits for counters in root layer - :param fan_out: int - fan out of chunked graph (2 == Octree) - :param instance_id: str - Google instance id - :param project_id: str - Google project id - :return: ChunkedGraph - """ - ws_cv = cloudvolume.CloudVolume(ws_cv_path) - - n_layers_agg = calc_n_layers(ws_cv, chunk_size, fan_out=2) - - if size is not None: - size = np.array(size) - - for i in range(len(ws_cv.info['scales'])): - original_size = ws_cv.info['scales'][i]['size'] - size = np.min([size, original_size], axis=0) - ws_cv.info['scales'][i]['size'] = [int(x) for x in size] - size[:-1] //= 2 - - n_layers_cg = calc_n_layers(ws_cv, chunk_size, fan_out=fan_out) - - dataset_info = ws_cv.info - dataset_info["mesh"] = cg_mesh_dir - dataset_info["data_dir"] = ws_cv_path - dataset_info["graph"] = {"chunk_size": [int(s) for s in chunk_size]} - - kwargs = {"table_id": cg_table_id, - "chunk_size": chunk_size, - "fan_out": np.uint64(fan_out), - "n_layers": np.uint64(n_layers_cg), - "dataset_info": dataset_info, - "use_skip_connections": use_skip_connections, - "s_bits_atomic_layer": s_bits_atomic_layer, - "n_bits_root_counter": n_bits_root_counter, - "is_new": True} - - if instance_id is not None: - kwargs["instance_id"] = instance_id - - if project_id is not None: - kwargs["project_id"] = project_id - - cg = chunkedgraph.ChunkedGraph(**kwargs) - - return cg, n_layers_agg diff --git a/pychunkedgraph/ingest/ingestionmanager.py b/pychunkedgraph/ingest/ingestionmanager.py deleted file mode 100644 index 55b132960..000000000 --- a/pychunkedgraph/ingest/ingestionmanager.py +++ /dev/null @@ -1,79 +0,0 @@ -import itertools -import numpy as np - -from pychunkedgraph.backend import chunkedgraph - - -class IngestionManager(object): - def __init__(self, storage_path, cg_table_id=None, n_layers=None, - instance_id=None, project_id=None): - self._storage_path = storage_path - self._cg_table_id = cg_table_id - self._instance_id = instance_id - self._project_id = project_id - self._cg = None - self._n_layers = n_layers - - @property - def storage_path(self): - return self._storage_path - - @property - def cg(self): - if self._cg is None: - kwargs = {} - - if self._instance_id is not None: - kwargs["instance_id"] = self._instance_id - - if self._project_id is not None: - kwargs["project_id"] = self._project_id - - self._cg = chunkedgraph.ChunkedGraph(table_id=self._cg_table_id, - **kwargs) - - return self._cg - - @property - def bounds(self): - bounds = self.cg.vx_vol_bounds.copy() - bounds -= self.cg.vx_vol_bounds[:, 0:1] - - return bounds - - @property - def chunk_id_bounds(self): - return np.ceil((self.bounds / self.cg.chunk_size[:, None])).astype(np.int) - - @property - def chunk_coord_gen(self): - return itertools.product(*[range(*r) for r in self.chunk_id_bounds]) - - @property - def chunk_coords(self): - return np.array(list(self.chunk_coord_gen), dtype=np.int) - - @property - def n_layers(self): - if self._n_layers is None: - self._n_layers = self.cg.n_layers - return self._n_layers - - def get_serialized_info(self): 
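The removed `calc_n_layers` helper above derived the octree depth from the volume extent, chunk size, and fan-out. A rough worked example of that arithmetic with made-up sizes:

```python
import numpy as np

chunk_size = np.array([256, 256, 512])
volume_size = np.array([40960, 40960, 4096])   # made-up voxel extent (bbox[1] - bbox[0])
fan_out = 2

n_chunks = (volume_size / chunk_size).astype(int)                        # [160 160 8]
n_layers = int(np.ceil(np.log2(np.max(n_chunks)) / np.log2(fan_out))) + 2
print(n_chunks, n_layers)                                                # [160 160 8] 10
```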
- info = {"storage_path": self.storage_path, - "cg_table_id": self._cg_table_id, - "n_layers": self.n_layers, - "instance_id": self._instance_id, - "project_id": self._project_id} - - return info - - def is_out_of_bounce(self, chunk_coordinate): - if np.any(chunk_coordinate < 0): - return True - - if np.any(chunk_coordinate > 2**self.cg.bitmasks[1]): - return True - - return False - diff --git a/pychunkedgraph/ingest/manager.py b/pychunkedgraph/ingest/manager.py new file mode 100644 index 000000000..f5f870810 --- /dev/null +++ b/pychunkedgraph/ingest/manager.py @@ -0,0 +1,56 @@ +import pickle + +from . import IngestConfig +from ..graph.meta import ChunkedGraphMeta +from ..graph.chunkedgraph import ChunkedGraph +from ..utils.redis import keys as r_keys +from ..utils.redis import get_rq_queue +from ..utils.redis import get_redis_connection + + +class IngestionManager: + def __init__(self, config: IngestConfig, chunkedgraph_meta: ChunkedGraphMeta): + self._config = config + self._chunkedgraph_meta = chunkedgraph_meta + self._cg = None + self._redis = None + self._task_queues = {} + self.redis # initiate and cache info + + @property + def config(self): + return self._config + + @property + def cg_meta(self): + return self._chunkedgraph_meta + + @property + def cg(self): + if self._cg is None: + self._cg = ChunkedGraph(meta=self.cg_meta) + return self._cg + + @property + def redis(self): + if self._redis is not None: + return self._redis + self._redis = get_redis_connection() + self._redis.set(r_keys.INGESTION_MANAGER, self.serialized(pickled=True)) + return self._redis + + def serialized(self, pickled=False): + params = {"config": self._config, "chunkedgraph_meta": self._chunkedgraph_meta} + if pickled: + return pickle.dumps(params) + return params + + @classmethod + def from_pickle(cls, serialized_info): + return cls(**pickle.loads(serialized_info)) + + def get_task_queue(self, q_name): + if q_name in self._task_queues: + return self._task_queues[q_name] + self._task_queues[q_name] = get_rq_queue(q_name) + return self._task_queues[q_name] diff --git a/pychunkedgraph/ingest/ran_agglomeration.py b/pychunkedgraph/ingest/ran_agglomeration.py new file mode 100644 index 000000000..7c4af51f7 --- /dev/null +++ b/pychunkedgraph/ingest/ran_agglomeration.py @@ -0,0 +1,427 @@ +# pylint: disable=invalid-name, missing-function-docstring +""" +plugin to read agglomeration data provided by Ran Lu +""" + +from collections import defaultdict +from itertools import product +from typing import Dict +from typing import Iterable +from typing import Tuple +from typing import Union +from binascii import crc32 + + +import pandas as pd +import networkx as nx +import numpy as np +import numpy.lib.recfunctions as rfn +from cloudfiles import CloudFiles + +from .manager import IngestionManager +from .utils import postprocess_edge_data +from ..io.edges import put_chunk_edges +from ..io.components import put_chunk_components +from ..graph.utils import basetypes +from ..graph.edges import Edges +from ..graph.edges import EDGE_TYPES +from ..graph.types import empty_2d +from ..graph.chunks.utils import get_chunk_id + +# see section below for description +CRC_LEN = 4 +VERSION_LEN = 4 +HEADER_LEN = 20 + +""" +Agglomeration data is now sharded. +Remap files and the region graph files are merged together +into bigger files with the following structure. + +For example, "in_chunk_xxx_yyy.data" files are merge into a single "in_chunk_xxx.data", +and the reader needs to find out the range to extract the data for each chunk. 
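The new `IngestionManager` above round-trips through Redis as a pickle of its constructor arguments. A stripped-down sketch of the same pattern, with a plain dict standing in for `ChunkedGraphMeta`:

```python
import pickle

class TinyManager:
    def __init__(self, config, meta):
        self.config = config
        self.meta = meta

    def serialized(self, pickled=False):
        params = {"config": self.config, "meta": self.meta}
        return pickle.dumps(params) if pickled else params

    @classmethod
    def from_pickle(cls, blob):
        return cls(**pickle.loads(blob))

m = TinyManager({"TEST_RUN": True}, {"layer_count": 9})
blob = m.serialized(pickled=True)    # in production this blob sits under a Redis key
m2 = TinyManager.from_pickle(blob)
assert m2.meta["layer_count"] == 9
```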
+ +The layout of the new files is like this: + + byte 1-4: 'SO01' (version identifier) + byte 5-12: Offset of the index information + byte 13-20: Length of the index information (including crc32) + byte 21-n: Payload data of the first chunk + byte (n+1)-(n+4): Crc32 of the remap data of first chunk + ... + ... + ... + byte m-l: index data: (chunkid, offset, length)*k + byte (l+1)-(l+4): Crc32 of the index data +""" + + +def read_raw_edge_data(imanager, coord) -> Dict: + edge_dict = _collect_edge_data(imanager, coord) + edge_dict = postprocess_edge_data(imanager, edge_dict) + + # flag to check if chunk has edges + # avoid writing to cloud storage if there are no edges + # unnecessary write operation + no_edges = True + chunk_edges = {} + for edge_type in EDGE_TYPES: + if not edge_dict[edge_type]: + chunk_edges[edge_type] = Edges(np.array([]), np.array([])) + continue + sv_ids1 = edge_dict[edge_type]["sv1"] + sv_ids2 = edge_dict[edge_type]["sv2"] + areas = np.ones(len(sv_ids1)) + affinities = float("inf") * areas + if not edge_type == EDGE_TYPES.cross_chunk: + affinities = edge_dict[edge_type]["aff"] + areas = edge_dict[edge_type]["area"] + + chunk_edges[edge_type] = Edges( + sv_ids1, sv_ids2, affinities=affinities, areas=areas + ) + no_edges = no_edges and not sv_ids1.size + if not no_edges and imanager.cg_meta.data_source.EDGES: + put_chunk_edges(imanager.cg_meta.data_source.EDGES, coord, chunk_edges, 17) + return chunk_edges + + +def _get_cont_chunk_coords(imanager, chunk_coord_a, chunk_coord_b): + """Computes chunk coordinates that compute data between the named chunks.""" + diff = chunk_coord_a - chunk_coord_b + dir_dim = np.where(diff != 0)[0] + assert len(dir_dim) == 1 + dir_dim = dir_dim[0] + + chunk_coord_l = chunk_coord_a if diff[dir_dim] > 0 else chunk_coord_b + c_chunk_coords = [] + for dx, dy, dz in product([0, -1], [0, -1], [0, -1]): + if dz == dy == dx == 0: + continue + if [dx, dy, dz][dir_dim] == 0: + continue + + c_chunk_coord = chunk_coord_l + np.array([dx, dy, dz]) + if imanager.cg_meta.is_out_of_bounds(c_chunk_coord): + continue + c_chunk_coords.append(c_chunk_coord) + return c_chunk_coords + + +def _get_index(cf: CloudFiles, filenames: Iterable[str], inchunk_or_agg: bool) -> dict: + header_range = {"start": 0, "end": HEADER_LEN} + finfos = [] + for fname in filenames: + finfo = {"path": fname} + finfo.update(header_range) + finfos.append(finfo) + + headers = cf.get(finfos, raw=True) + index_infos = [] + for header in headers: + content = header["content"] + if content is None: + continue + content = content[VERSION_LEN:] + idx_offset, idx_length = np.frombuffer(content, dtype=np.uint64) + index_info = { + "path": header["path"], + "start": idx_offset, + "end": idx_offset + idx_length, + } + index_infos.append(index_info) + + files_index = {} + index_datas = cf.get(index_infos, raw=True) + for index_data in index_datas: + content = index_data["content"] + index, crc = content[:-CRC_LEN], content[-CRC_LEN:] + crc = np.frombuffer(crc, dtype=np.uint32)[0] + assert crc32(index) == crc + + dt = np.dtype([("chunkid", "u8"), ("offset", "u8"), ("size", "u8")]) + if inchunk_or_agg is False: + dt = np.dtype([("chunkid", "2u8"), ("offset", "u8"), ("size", "u8")]) + files_index[index_data["path"]] = np.frombuffer(index, dtype=dt) + return files_index + + +def _crc_check(payload: bytes) -> None: + payload_crc32 = np.frombuffer(payload[-CRC_LEN:], dtype=np.uint32) + assert np.frombuffer(payload_crc32, dtype=np.uint32)[0] == crc32(payload[:-CRC_LEN]) + + +def 
_parse_edge_payloads(payloads, edge_dtype): + result = [] + for payload in payloads: + content = payload["content"] + if content is None: + continue + _crc_check(content) + result.append(np.frombuffer(content[:-CRC_LEN], dtype=edge_dtype)) + return result + + +def _read_in_chunk_files( + chunk_id: basetypes.NODE_ID, + path: str, + filenames: Iterable[str], + edge_dtype: Iterable[Tuple], +): + cf = CloudFiles(path) + files_index = _get_index(cf, filenames, inchunk_or_agg=True) + + finfos = [] + for fname, index in files_index.items(): + for chunk in index: + if chunk["chunkid"] == chunk_id: + finfo = {"path": fname} + finfo["start"] = chunk["offset"] + finfo["end"] = chunk["offset"] + chunk["size"] + finfos.append(finfo) + + payloads = cf.get(finfos, raw=True) + return _parse_edge_payloads(payloads, edge_dtype) + + +def _read_between_or_fake_chunk_files( + chunk_id: basetypes.NODE_ID, + adjacent_id: basetypes.NODE_ID, + path: str, + filenames: Iterable[str], + edge_dtype: Iterable[Tuple], +): + cf = CloudFiles(path) + files_index = _get_index(cf, filenames, inchunk_or_agg=False) + + chunk_finfos = [] + adj_chunk_finfos = [] + for fname, index in files_index.items(): + for chunk in index: + chunk0, chunk1 = chunk["chunkid"][0], chunk["chunkid"][1] + if chunk0 == chunk_id and chunk1 == adjacent_id: + finfo = {"path": fname} + finfo["start"] = chunk["offset"] + finfo["end"] = chunk["offset"] + chunk["size"] + chunk_finfos.append(finfo) + if chunk0 == adjacent_id and chunk1 == chunk_id: + finfo = {"path": fname} + finfo["start"] = chunk["offset"] + finfo["end"] = chunk["offset"] + chunk["size"] + adj_chunk_finfos.append(finfo) + + result = [] + chunk_payloads = cf.get(chunk_finfos, raw=True) + adj_chunk_payloads = cf.get(adj_chunk_finfos, raw=True) + result = _parse_edge_payloads(chunk_payloads, edge_dtype=edge_dtype) + + dtype = [edge_dtype[1], edge_dtype[0]] + edge_dtype[2:] + adj_result = _parse_edge_payloads(adj_chunk_payloads, edge_dtype=dtype) + return result + adj_result + + +def _collect_edge_data(imanager: IngestionManager, chunk_coord): + """Loads edge for single chunk.""" + cg_meta = imanager.cg_meta + edge_dtype = cg_meta.edge_dtype + subfolder = "chunked_rg" + path = f"{imanager.config.AGGLOMERATION}/{subfolder}/" + chunk_coord = np.array(chunk_coord) + x, y, z = chunk_coord + chunk_id = get_chunk_id(cg_meta, layer=1, x=x, y=y, z=z) + + edge_data = defaultdict(list) + in_fnames = [] + x, y, z = chunk_coord + for _x, _y, _z in product([x - 1, x], [y - 1, y], [z - 1, z]): + if cg_meta.is_out_of_bounds(np.array([_x, _y, _z])): + continue + filename = f"in_chunk_0_{_x}_{_y}_{_z}.data" + in_fnames.append(filename) + + edge_data[EDGE_TYPES.in_chunk] = _read_in_chunk_files( + chunk_id, + path, + in_fnames, + edge_dtype, + ) + for d in [-1, 1]: + for dim in range(3): + diff = np.zeros([3], dtype=int) + diff[dim] = d + adjacent_coord = chunk_coord + diff + x, y, z = adjacent_coord + adjacent_id = get_chunk_id(cg_meta, layer=1, x=x, y=y, z=z) + if cg_meta.is_out_of_bounds(adjacent_coord): + continue + + cont_coords = _get_cont_chunk_coords(imanager, chunk_coord, adjacent_coord) + bt_fnames = [] + cx_fnames = [] + for c_chunk_coord in cont_coords: + x, y, z = c_chunk_coord + filename = f"between_chunks_0_{x}_{y}_{z}.data" + bt_fnames.append(filename) + + # EDGES FROM CUTS OF SVS + filename = f"fake_0_{x}_{y}_{z}.data" + cx_fnames.append(filename) + + for edge_type, fnames in [ + (EDGE_TYPES.between_chunk, bt_fnames), + (EDGE_TYPES.cross_chunk, cx_fnames), + ]: + _data = 
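To make the byte layout that `_get_index` walks through concrete, here is a self-contained sketch that builds a fake shard in memory and parses it back; the field widths follow the module docstring above (4-byte version, 8-byte index offset, 8-byte index length, crc32-protected index), and the chunk id and payload are invented:

```python
import numpy as np
from binascii import crc32

# build a fake merged file: 20-byte header, one payload blob, crc-protected index
payload = b"\x01\x02\x03\x04"                       # stand-in chunk payload
index = np.array([(42, 20, len(payload))],
                 dtype=[("chunkid", "u8"), ("offset", "u8"), ("size", "u8")]).tobytes()
index += np.uint32(crc32(index)).tobytes()          # index is crc32-protected
header = b"SO01" + np.array([20 + len(payload), len(index)], dtype=np.uint64).tobytes()
blob = header + payload + index

# parse it back the same way the index reader does
idx_offset, idx_length = np.frombuffer(blob[4:20], dtype=np.uint64)
raw_index = blob[int(idx_offset): int(idx_offset + idx_length)]
body, crc = raw_index[:-4], np.frombuffer(raw_index[-4:], dtype=np.uint32)[0]
assert crc32(body) == crc
records = np.frombuffer(body, dtype=[("chunkid", "u8"), ("offset", "u8"), ("size", "u8")])
print(records)   # [(42, 20, 4)] -> chunk 42's payload starts at byte 20 and is 4 bytes long
```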
_read_between_or_fake_chunk_files( + chunk_id, + adjacent_id, + path, + fnames, + edge_dtype, + ) + edge_data[edge_type].extend(_data) + + for k in EDGE_TYPES: + if not edge_data[k]: + continue + edge_data[k] = rfn.stack_arrays(edge_data[k], usemask=False) + edge_data_df = pd.DataFrame(edge_data[k]) + edge_data_dfg = ( + edge_data_df.groupby(["sv1", "sv2"]).aggregate(np.sum).reset_index() + ) + edge_data[k] = edge_data_dfg.to_records() + return edge_data + + +def get_active_edges(edges_d, mapping): + active_edges_flag_d, isolated_ids = define_active_edges(edges_d, mapping) + chunk_edges_active = {} + pseudo_isolated_ids = [isolated_ids] + for edge_type in EDGE_TYPES: + edges = edges_d[edge_type] + + active = ( + np.ones(len(edges), dtype=bool) + if edge_type == EDGE_TYPES.cross_chunk + else active_edges_flag_d[edge_type] + ) + + sv_ids1 = edges.node_ids1[active] + sv_ids2 = edges.node_ids2[active] + affinities = edges.affinities[active] + areas = edges.areas[active] + chunk_edges_active[edge_type] = Edges( + sv_ids1, sv_ids2, affinities=affinities, areas=areas + ) + # assume all ids within the chunk are isolated + # to make sure all end up in connected components + pseudo_isolated_ids.append(edges.node_ids1) + if edge_type == EDGE_TYPES.in_chunk: + pseudo_isolated_ids.append(edges.node_ids2) + + return chunk_edges_active, np.unique(np.concatenate(pseudo_isolated_ids)) + + +def define_active_edges(edge_dict, mapping) -> Union[Dict, np.ndarray]: + """Labels edges as within or across segments and extracts isolated ids + :return: dict of np.ndarrays, np.ndarray + bool arrays; True: connected (within same segment) + isolated node ids + """ + mapping_vec = np.vectorize(lambda k: mapping.get(k, -1)) + active = {} + isolated = [[]] + for k in edge_dict: + if len(edge_dict[k].node_ids1) > 0: + agg_id_1 = mapping_vec(edge_dict[k].node_ids1) + else: + assert len(edge_dict[k].node_ids2) == 0 + active[k] = np.array([], dtype=bool) + continue + + agg_id_2 = mapping_vec(edge_dict[k].node_ids2) + active[k] = agg_id_1 == agg_id_2 + # Set those with two -1 to False + agg_1_m = agg_id_1 == -1 + agg_2_m = agg_id_2 == -1 + active[k][agg_1_m] = False + + isolated.append(edge_dict[k].node_ids1[agg_1_m]) + if k == EDGE_TYPES.in_chunk: + isolated.append(edge_dict[k].node_ids2[agg_2_m]) + return active, np.unique(np.concatenate(isolated).astype(basetypes.NODE_ID)) + + +def read_raw_agglomeration_data(imanager: IngestionManager, chunk_coord: np.ndarray): + """ + Collects agglomeration information & builds connected component mapping + """ + cg_meta = imanager.cg_meta + subfolder = "remap" + path = f"{imanager.config.AGGLOMERATION}/{subfolder}/" + chunk_coord = np.array(chunk_coord) + x, y, z = chunk_coord + chunk_id = get_chunk_id(cg_meta, layer=1, x=x, y=y, z=z) + + filenames = [] + chunk_ids = [] + for mip_level in range(0, int(cg_meta.layer_count - 1)): + x, y, z = np.array(chunk_coord / 2**mip_level, dtype=int) + filenames.append(f"done_{mip_level}_{x}_{y}_{z}.data") + chunk_ids.append(chunk_id) + + for d in [-1, 1]: + for dim in range(3): + diff = np.zeros([3], dtype=int) + diff[dim] = d + adjacent_coord = chunk_coord + diff + x, y, z = adjacent_coord + adjacent_id = get_chunk_id(cg_meta, layer=1, x=x, y=y, z=z) + + for mip_level in range(0, int(cg_meta.layer_count - 1)): + x, y, z = np.array(adjacent_coord / 2**mip_level, dtype=int) + filenames.append(f"done_{mip_level}_{x}_{y}_{z}.data") + chunk_ids.append(adjacent_id) + + edges_list = _read_agg_files(filenames, chunk_ids, path) + G = nx.Graph() + 
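The active-edge labelling in `define_active_edges` above boils down to a vectorised mapping lookup: an edge is active when both supervoxels map to the same agglomeration component, and unmapped ids are collected as isolated. A toy version with made-up ids:

```python
import numpy as np

mapping = {1: 0, 2: 0, 3: 1}                 # supervoxel id -> agglomeration component
node_ids1 = np.array([1, 1, 4])              # 4 is unmapped (outside the agglomeration)
node_ids2 = np.array([2, 3, 2])

lookup = np.vectorize(lambda k: mapping.get(k, -1))
agg1, agg2 = lookup(node_ids1), lookup(node_ids2)

active = agg1 == agg2
active[agg1 == -1] = False                   # unmapped first endpoints are never active
isolated = node_ids1[agg1 == -1]             # ids that fall outside the mapping
print(active)                                # [ True False False]
print(isolated)                              # [4]
```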
G.add_edges_from(np.concatenate(edges_list)) + mapping = {} + components = list(nx.connected_components(G)) + for i_cc, cc in enumerate(components): + cc = list(cc) + mapping.update(dict(zip(cc, [i_cc] * len(cc)))) + + if mapping and cg_meta.data_source.COMPONENTS: + put_chunk_components(cg_meta.data_source.COMPONENTS, components, chunk_coord) + return mapping + + +def _read_agg_files(filenames, chunk_ids, path): + cf = CloudFiles(path) + finfos = [] + files_index = _get_index(cf, set(filenames), inchunk_or_agg=True) + + for fname, chunk_id in zip(filenames, chunk_ids): + try: + index = files_index[fname] + except KeyError: + continue + for chunk in index: + if chunk["chunkid"] == chunk_id: + finfo = {"path": fname} + finfo["start"] = chunk["offset"] + finfo["end"] = chunk["offset"] + chunk["size"] + finfos.append(finfo) + break + + edge_list = [empty_2d] + payloads = cf.get(finfos, raw=True) + for payload in payloads: + cont = payload["content"] + if cont is None: + continue + _crc_check(cont) + edges = np.frombuffer(cont[:-CRC_LEN], dtype=basetypes.NODE_ID).reshape(-1, 2) + if edges is not None: + edge_list.append(edges) + return edge_list diff --git a/pychunkedgraph/ingest/ran_ingestion.py b/pychunkedgraph/ingest/ran_ingestion.py deleted file mode 100644 index 7a742d1bf..000000000 --- a/pychunkedgraph/ingest/ran_ingestion.py +++ /dev/null @@ -1,556 +0,0 @@ -import collections -import time -import random - -import pandas as pd -import cloudvolume -import networkx as nx -import numpy as np -import numpy.lib.recfunctions as rfn -import zstandard as zstd -from multiwrapper import multiprocessing_utils as mu - -from pychunkedgraph.ingest import ingestionmanager, ingestion_utils as iu - - -def ingest_into_chunkedgraph(storage_path, ws_cv_path, cg_table_id, - chunk_size=[256, 256, 512], - use_skip_connections=True, - s_bits_atomic_layer=None, - fan_out=2, aff_dtype=np.float32, - size=None, - instance_id=None, project_id=None, - start_layer=1, n_threads=[64, 64]): - """ Creates a chunkedgraph from a Ran Agglomerattion - - :param storage_path: str - Google cloud bucket path (agglomeration) - example: gs://ranl-scratch/minnie_test_2 - :param ws_cv_path: str - Google cloud bucket path (watershed segmentation) - example: gs://microns-seunglab/minnie_v0/minnie10/ws_minnie_test_2/agg - :param cg_table_id: str - chunkedgraph table name - :param fan_out: int - fan out of chunked graph (2 == Octree) - :param aff_dtype: np.dtype - affinity datatype (np.float32 or np.float64) - :param instance_id: str - Google instance id - :param project_id: str - Google project id - :param start_layer: int - :param n_threads: list of ints - number of threads to use - :return: - """ - storage_path = storage_path.strip("/") - ws_cv_path = ws_cv_path.strip("/") - - cg_mesh_dir = f"{cg_table_id}_meshes" - chunk_size = np.array(chunk_size, dtype=np.uint64) - - _, n_layers_agg = iu.initialize_chunkedgraph(cg_table_id=cg_table_id, - ws_cv_path=ws_cv_path, - chunk_size=chunk_size, - size=size, - use_skip_connections=use_skip_connections, - s_bits_atomic_layer=s_bits_atomic_layer, - cg_mesh_dir=cg_mesh_dir, - fan_out=fan_out, - instance_id=instance_id, - project_id=project_id) - - im = ingestionmanager.IngestionManager(storage_path=storage_path, - cg_table_id=cg_table_id, - n_layers=n_layers_agg, - instance_id=instance_id, - project_id=project_id) - - # #TODO: Remove later: - # logging.basicConfig(level=logging.DEBUG) - # im.cg.logger = logging.getLogger(__name__) - # ------------------------------------------ - if start_layer < 
3: - create_atomic_chunks(im, aff_dtype=aff_dtype, n_threads=n_threads[0]) - - create_abstract_layers(im, n_threads=n_threads[1], start_layer=start_layer) - - return im - - -def create_abstract_layers(im, start_layer=3, n_threads=1): - """ Creates abstract of chunkedgraph (> 2) - - :param im: IngestionManager - :param n_threads: int - number of threads to use - :return: - """ - if start_layer < 3: - start_layer = 3 - - assert start_layer < int(im.cg.n_layers + 1) - - for layer_id in range(start_layer, int(im.cg.n_layers + 1)): - create_layer(im, layer_id, n_threads=n_threads) - - -def create_layer(im, layer_id, block_size=100, n_threads=1): - """ Creates abstract layer of chunkedgraph - - Abstract layers have to be build in sequence. Abstract layers are all layers - above the first layer (1). `create_atomic_chunks` creates layer 2 as well. - Hence, this function is responsible for every creating layers > 2. - - :param im: IngestionManager - :param layer_id: int - > 2 - :param n_threads: int - number of threads to use - :return: - """ - assert layer_id > 2 - - child_chunk_coords = im.chunk_coords // im.cg.fan_out ** (layer_id - 3) - child_chunk_coords = child_chunk_coords.astype(np.int) - child_chunk_coords = np.unique(child_chunk_coords, axis=0) - - parent_chunk_coords = child_chunk_coords // im.cg.fan_out - parent_chunk_coords = parent_chunk_coords.astype(np.int) - parent_chunk_coords, inds = np.unique(parent_chunk_coords, axis=0, - return_inverse=True) - - im_info = im.get_serialized_info() - multi_args = [] - - # Randomize chunks - order = np.arange(len(parent_chunk_coords), dtype=np.int) - np.random.shuffle(order) - - # Block chunks - block_size = min(block_size, int(np.ceil(len(order) / n_threads / 3))) - n_blocks = int(len(order) / block_size) - blocks = np.array_split(order, n_blocks) - - for i_block, block in enumerate(blocks): - chunks = [] - for idx in block: - chunks.append(child_chunk_coords[inds == idx]) - - multi_args.append([im_info, layer_id, len(order), n_blocks, i_block, - chunks]) - - if n_threads == 1: - mu.multiprocess_func( - _create_layers, multi_args, n_threads=n_threads, - verbose=True, debug=n_threads == 1) - else: - mu.multisubprocess_func(_create_layers, multi_args, n_threads=n_threads, - suffix=f"{layer_id}") - - -def _create_layers(args): - """ Multiprocessing helper for create_layer """ - im_info, layer_id, n_chunks, n_blocks, i_block, chunks = args - im = ingestionmanager.IngestionManager(**im_info) - - for i_chunk, child_chunk_coords in enumerate(chunks): - time_start = time.time() - - im.cg.add_layer(layer_id, child_chunk_coords, n_threads=8, verbose=True) - - print(f"Layer {layer_id} - Job {i_block + 1} / {n_blocks} - " - f"{i_chunk + 1} / {len(chunks)} -- %.3fs" % - (time.time() - time_start)) - - -def create_atomic_chunks(im, aff_dtype=np.float32, n_threads=1, block_size=100): - """ Creates all atomic chunks - - :param im: IngestionManager - :param aff_dtype: np.dtype - affinity datatype (np.float32 or np.float64) - :param n_threads: int - number of threads to use - :return: - """ - - im_info = im.get_serialized_info() - - multi_args = [] - - # Randomize chunk order - chunk_coords = list(im.chunk_coord_gen) - order = np.arange(len(chunk_coords), dtype=np.int) - np.random.shuffle(order) - - # Block chunks - block_size = min(block_size, int(np.ceil(len(chunk_coords) / n_threads / 3))) - n_blocks = int(len(chunk_coords) / block_size) - blocks = np.array_split(order, n_blocks) - - for i_block, block in enumerate(blocks): - chunks = [] - for b_idx in block: 
- chunks.append(chunk_coords[b_idx]) - - multi_args.append([im_info, aff_dtype, n_blocks, i_block, chunks]) - - if n_threads == 1: - mu.multiprocess_func( - _create_atomic_chunk, multi_args, n_threads=n_threads, - verbose=True, debug=n_threads == 1) - else: - mu.multisubprocess_func( - _create_atomic_chunk, multi_args, n_threads=n_threads) - - -def _create_atomic_chunk(args): - """ Multiprocessing helper for create_atomic_chunks """ - im_info, aff_dtype, n_blocks, i_block, chunks = args - im = ingestionmanager.IngestionManager(**im_info) - - for i_chunk, chunk_coord in enumerate(chunks): - time_start = time.time() - - create_atomic_chunk(im, chunk_coord, aff_dtype=aff_dtype, verbose=True) - - print(f"Layer 1 - {chunk_coord} - Job {i_block + 1} / {n_blocks} - " - f"{i_chunk + 1} / {len(chunks)} -- %.3fs" % - (time.time() - time_start)) - - -def create_atomic_chunk(im, chunk_coord, aff_dtype=np.float32, verbose=True): - """ Creates single atomic chunk - - :param im: IngestionManager - :param chunk_coord: np.ndarray - array of three ints - :param aff_dtype: np.dtype - np.float64 or np.float32 - :param verbose: bool - :return: - """ - chunk_coord = np.array(list(chunk_coord), dtype=np.int) - - edge_dict = collect_edge_data(im, chunk_coord, aff_dtype=aff_dtype) - mapping = collect_agglomeration_data(im, chunk_coord) - active_edge_dict, isolated_ids = define_active_edges(edge_dict, mapping) - - edge_ids = {} - edge_affs = {} - edge_areas = {} - - for k in edge_dict.keys(): - if k == "cross": - edge_ids[k] = np.concatenate([edge_dict[k]["sv1"][:, None], - edge_dict[k]["sv2"][:, None]], - axis=1) - continue - - sv1_conn = edge_dict[k]["sv1"][active_edge_dict[k]] - sv2_conn = edge_dict[k]["sv2"][active_edge_dict[k]] - aff_conn = edge_dict[k]["aff"][active_edge_dict[k]] - area_conn = edge_dict[k]["area"][active_edge_dict[k]] - edge_ids[f"{k}_connected"] = np.concatenate([sv1_conn[:, None], - sv2_conn[:, None]], - axis=1) - edge_affs[f"{k}_connected"] = aff_conn.astype(np.float32) - edge_areas[f"{k}_connected"] = area_conn - - sv1_disconn = edge_dict[k]["sv1"][~active_edge_dict[k]] - sv2_disconn = edge_dict[k]["sv2"][~active_edge_dict[k]] - aff_disconn = edge_dict[k]["aff"][~active_edge_dict[k]] - area_disconn = edge_dict[k]["area"][~active_edge_dict[k]] - edge_ids[f"{k}_disconnected"] = np.concatenate([sv1_disconn[:, None], - sv2_disconn[:, None]], - axis=1) - edge_affs[f"{k}_disconnected"] = aff_disconn.astype(np.float32) - edge_areas[f"{k}_disconnected"] = area_disconn - - im.cg.add_atomic_edges_in_chunks(edge_ids, edge_affs, edge_areas, - isolated_node_ids=isolated_ids) - - return edge_ids, edge_affs, edge_areas - - -def _get_cont_chunk_coords(im, chunk_coord_a, chunk_coord_b): - """ Computes chunk coordinates that compute data between the named chunks - - :param im: IngestionManagaer - :param chunk_coord_a: np.ndarray - array of three ints - :param chunk_coord_b: np.ndarray - array of three ints - :return: np.ndarray - """ - - diff = chunk_coord_a - chunk_coord_b - - dir_dim = np.where(diff != 0)[0] - assert len(dir_dim) == 1 - dir_dim = dir_dim[0] - - if diff[dir_dim] > 0: - chunk_coord_l = chunk_coord_a - else: - chunk_coord_l = chunk_coord_b - - c_chunk_coords = [] - for dx in [-1, 0]: - for dy in [-1, 0]: - for dz in [-1, 0]: - if dz == dy == dx == 0: - continue - - c_chunk_coord = chunk_coord_l + np.array([dx, dy, dz]) - - if [dx, dy, dz][dir_dim] == 0: - continue - - if im.is_out_of_bounce(c_chunk_coord): - continue - - c_chunk_coords.append(c_chunk_coord) - - return c_chunk_coords - - 
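# Worked example for the helper above (illustrative coordinates): for adjacent chunks
# [1, 1, 1] and [2, 1, 1] the direction is x, chunk_coord_l is [2, 1, 1], and the candidate
# contact-chunk coordinates are [1, 0, 0], [1, 0, 1], [1, 1, 0] and [1, 1, 1],
# minus any that fall out of bounds.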
-def collect_edge_data(im, chunk_coord, aff_dtype=np.float32): - """ Loads edge for single chunk - - :param im: IngestionManager - :param chunk_coord: np.ndarray - array of three ints - :param aff_dtype: np.dtype - :return: dict of np.ndarrays - """ - subfolder = "chunked_rg" - - base_path = f"{im.storage_path}/{subfolder}/" - - chunk_coord = np.array(chunk_coord) - - chunk_id = im.cg.get_chunk_id(layer=1, x=chunk_coord[0], y=chunk_coord[1], - z=chunk_coord[2]) - - filenames = collections.defaultdict(list) - swap = collections.defaultdict(list) - for x in [chunk_coord[0] - 1, chunk_coord[0]]: - for y in [chunk_coord[1] - 1, chunk_coord[1]]: - for z in [chunk_coord[2] - 1, chunk_coord[2]]: - - if im.is_out_of_bounce(np.array([x, y, z])): - continue - - # EDGES WITHIN CHUNKS - filename = f"in_chunk_0_{x}_{y}_{z}_{chunk_id}.data" - filenames["in"].append(filename) - - for d in [-1, 1]: - for dim in range(3): - diff = np.zeros([3], dtype=np.int) - diff[dim] = d - - adjacent_chunk_coord = chunk_coord + diff - adjacent_chunk_id = im.cg.get_chunk_id(layer=1, - x=adjacent_chunk_coord[0], - y=adjacent_chunk_coord[1], - z=adjacent_chunk_coord[2]) - - if im.is_out_of_bounce(adjacent_chunk_coord): - continue - - c_chunk_coords = _get_cont_chunk_coords(im, chunk_coord, - adjacent_chunk_coord) - - larger_id = np.max([chunk_id, adjacent_chunk_id]) - smaller_id = np.min([chunk_id, adjacent_chunk_id]) - chunk_id_string = f"{smaller_id}_{larger_id}" - - for c_chunk_coord in c_chunk_coords: - x, y, z = c_chunk_coord - - # EDGES BETWEEN CHUNKS - filename = f"between_chunks_0_{x}_{y}_{z}_{chunk_id_string}.data" - filenames["between"].append(filename) - - swap[filename] = larger_id == chunk_id - - # EDGES FROM CUTS OF SVS - filename = f"fake_0_{x}_{y}_{z}_{chunk_id_string}.data" - filenames["cross"].append(filename) - - swap[filename] = larger_id == chunk_id - - edge_data = {} - read_counter = collections.Counter() - - dtype = [("sv1", np.uint64), ("sv2", np.uint64), - ("aff", aff_dtype), ("area", np.uint64)] - for k in filenames: - # print(k, len(filenames[k])) - - with cloudvolume.Storage(base_path, n_threads=10) as stor: - files = stor.get_files(filenames[k]) - - data = [] - for file in files: - if file["content"] is None: - # print(f"{file['filename']} not created or empty") - continue - - if file["error"] is not None: - # print(f"error reading {file['filename']}") - continue - - if swap[file["filename"]]: - this_dtype = [dtype[1], dtype[0], dtype[2], dtype[3]] - content = np.frombuffer(file["content"], dtype=this_dtype) - else: - content = np.frombuffer(file["content"], dtype=dtype) - - data.append(content) - - read_counter[k] += 1 - - try: - edge_data[k] = rfn.stack_arrays(data, usemask=False) - except: - raise() - - edge_data_df = pd.DataFrame(edge_data[k]) - edge_data_dfg = edge_data_df.groupby(["sv1", "sv2"]).aggregate(np.sum).reset_index() - edge_data[k] = edge_data_dfg.to_records() - - # # TEST - # with cloudvolume.Storage(base_path, n_threads=10) as stor: - # files = list(stor.list_files()) - # - # true_counter = collections.Counter() - # for file in files: - # if str(chunk_id) in file: - # true_counter[file.split("_")[0]] += 1 - # - # print("Truth", true_counter) - # print("Reality", read_counter) - - return edge_data - - -def _read_agg_files(filenames, base_path): - with cloudvolume.Storage(base_path, n_threads=10) as stor: - files = stor.get_files(filenames) - - edge_list = [] - for file in files: - if file["content"] is None: - continue - - if file["error"] is not None: - continue - - content = 
zstd.ZstdDecompressor().decompressobj().decompress(file["content"]) - edge_list.append(np.frombuffer(content, dtype=np.uint64).reshape(-1, 2)) - - return edge_list - - -def collect_agglomeration_data(im, chunk_coord): - """ Collects agglomeration information & builds connected component mapping - - :param im: IngestionManager - :param chunk_coord: np.ndarray - array of three ints - :return: dictionary - """ - subfolder = "remap" - base_path = f"{im.storage_path}/{subfolder}/" - - chunk_coord = np.array(chunk_coord) - - chunk_id = im.cg.get_chunk_id(layer=1, x=chunk_coord[0], y=chunk_coord[1], - z=chunk_coord[2]) - - filenames = [] - for mip_level in range(0, int(im.n_layers - 1)): - x, y, z = np.array(chunk_coord / 2 ** mip_level, dtype=np.int) - filenames.append(f"done_{mip_level}_{x}_{y}_{z}_{chunk_id}.data.zst") - - for d in [-1, 1]: - for dim in range(3): - diff = np.zeros([3], dtype=np.int) - diff[dim] = d - - adjacent_chunk_coord = chunk_coord + diff - - adjacent_chunk_id = im.cg.get_chunk_id(layer=1, - x=adjacent_chunk_coord[0], - y=adjacent_chunk_coord[1], - z=adjacent_chunk_coord[2]) - - for mip_level in range(0, int(im.n_layers - 1)): - x, y, z = np.array(adjacent_chunk_coord / 2 ** mip_level, dtype=np.int) - filenames.append(f"done_{mip_level}_{x}_{y}_{z}_{adjacent_chunk_id}.data.zst") - - # print(filenames) - edge_list = _read_agg_files(filenames, base_path) - - edges = np.concatenate(edge_list) - - G = nx.Graph() - G.add_edges_from(edges) - ccs = nx.connected_components(G) - - mapping = {} - for i_cc, cc in enumerate(ccs): - cc = list(cc) - mapping.update(dict(zip(cc, [i_cc] * len(cc)))) - - return mapping - - -def define_active_edges(edge_dict, mapping): - """ Labels edges as within or across segments and extracts isolated ids - - :param edge_dict: dict of np.ndarrays - :param mapping: dict - :return: dict of np.ndarrays, np.ndarray - bool arrays; True: connected (within same segment) - isolated node ids - """ - def _mapping_default(key): - if key in mapping: - return mapping[key] - else: - return -1 - - mapping_vec = np.vectorize(_mapping_default) - - active = {} - isolated = [[]] - for k in edge_dict: - if len(edge_dict[k]["sv1"]) > 0: - agg_id_1 = mapping_vec(edge_dict[k]["sv1"]) - else: - assert len(edge_dict[k]["sv2"]) == 0 - active[k] = np.array([], dtype=np.bool) - continue - - agg_id_2 = mapping_vec(edge_dict[k]["sv2"]) - - active[k] = agg_id_1 == agg_id_2 - - # Set those with two -1 to False - agg_1_m = agg_id_1 == -1 - agg_2_m = agg_id_2 == -1 - active[k][agg_1_m] = False - - isolated.append(edge_dict[k]["sv1"][agg_1_m]) - - if k == "in": - isolated.append(edge_dict[k]["sv2"][agg_2_m]) - - return active, np.unique(np.concatenate(isolated).astype(np.uint64)) - diff --git a/pychunkedgraph/ingest/rq_cli.py b/pychunkedgraph/ingest/rq_cli.py new file mode 100644 index 000000000..27b9c865d --- /dev/null +++ b/pychunkedgraph/ingest/rq_cli.py @@ -0,0 +1,138 @@ +""" +cli for redis jobs +""" +import os +import sys + +import click +from redis import Redis +from rq import Queue +from rq import Worker +from rq.worker import WorkerStatus +from rq.job import Job +from rq.exceptions import InvalidJobOperationError +from rq.exceptions import NoSuchJobError +from rq.registry import StartedJobRegistry +from rq.registry import FailedJobRegistry +from flask import current_app +from flask.cli import AppGroup + +from ..utils.redis import REDIS_HOST +from ..utils.redis import REDIS_PORT +from ..utils.redis import REDIS_PASSWORD + + +# rq extended +rq_cli = AppGroup("rq") +connection = 
Redis(host=REDIS_HOST, port=REDIS_PORT, db=0, password=REDIS_PASSWORD) + + +@rq_cli.command("status") +@click.argument("queues", nargs=-1, type=str) +@click.option("--show-busy", is_flag=True) +def get_status(queues, show_busy): + print("NOTE: Use --show-busy to display count of non idle workers\n") + for queue in queues: + q = Queue(queue, connection=connection) + print(f"Queue name \t: {queue}") + print(f"Jobs queued \t: {len(q)}") + print(f"Workers total \t: {Worker.count(queue=q)}") + if show_busy: + workers = Worker.all(queue=q) + count = sum([worker.get_state() == WorkerStatus.BUSY for worker in workers]) + print(f"Workers busy \t: {count}") + print(f"Jobs failed \t: {q.failed_job_registry.count}\n") + + +@rq_cli.command("failed") +@click.argument("queue", type=str) +@click.argument("job_ids", nargs=-1) +def failed_jobs(queue, job_ids): + if job_ids: + for job_id in job_ids: + j = Job.fetch(job_id, connection=connection) + print(f"JOB ID {job_id}") + print("KWARGS") + print(j.kwargs) + print("\nARGS") + print(j.args) + print("\nEXCEPTION") + print(j.exc_info) + else: + q = Queue(queue, connection=connection) + ids = q.failed_job_registry.get_job_ids() + print("\n".join(ids)) + + +@rq_cli.command("empty") +@click.argument("queue", type=str) +def empty_queue(queue): + q = Queue(queue, connection=connection) + job_count = len(q) + q.empty() + print(f"{job_count} jobs removed from {queue}.") + + +@rq_cli.command("reenqueue") +@click.argument("queue", type=str) +@click.argument("job_ids", nargs=-1, required=True) +def enqueue(queue, job_ids): + """Enqueues *existing* jobs that are stuck for whatever reason.""" + q = Queue(queue, connection=connection) + for job_id in job_ids: + q.push_job_id(job_id) + + +@rq_cli.command("requeue") +@click.argument("queue", type=str) +@click.option("--all", "-a", is_flag=True, help="Requeue all failed jobs") +@click.argument("job_ids", nargs=-1) +def requeue(queue, all, job_ids): + """Requeue failed jobs.""" + failed_job_registry = FailedJobRegistry(queue, connection=connection) + if all: + job_ids = failed_job_registry.get_job_ids() + + if not job_ids: + click.echo("Nothing to do") + sys.exit(0) + + click.echo(f"Requeueing {len(job_ids)} jobs from failed queue") + fail_count = 0 + for job_id in job_ids: + try: + failed_job_registry.requeue(job_id) + except (InvalidJobOperationError, NoSuchJobError): + fail_count += 1 + + if fail_count > 0: + click.secho( + f"Unable to requeue {fail_count} jobs from failed job registry", fg="red" + ) + + +@rq_cli.command("cleanup") +@click.argument("queue", type=str) +def clean_start_registry(queue): + """ + Clean started job registry + Sometimes started jobs are not moved to failed registry (network issues) + This command takes the jobs off the started registry and reueues them + """ + registry = StartedJobRegistry(name=queue, connection=connection) + cleaned_jobs = registry.cleanup() + print(f"Requeued {len(cleaned_jobs)} jobs from the started job registry.") + + +@rq_cli.command("clear_failed") +@click.argument("queue", type=str) +def clear_failed_registry(queue): + failed_job_registry = FailedJobRegistry(queue, connection=connection) + job_ids = failed_job_registry.get_job_ids() + for job_id in job_ids: + failed_job_registry.remove(job_id, delete_job=True) + print(f"Deleted {len(job_ids)} jobs from the failed job registry.") + + +def init_rq_cmds(app): + app.cli.add_command(rq_cli) diff --git a/pychunkedgraph/ingest/utils.py b/pychunkedgraph/ingest/utils.py new file mode 100644 index 000000000..fa7ef7a3c --- /dev/null 
+++ b/pychunkedgraph/ingest/utils.py @@ -0,0 +1,75 @@ +from typing import Tuple + + +from . import ClusterIngestConfig +from . import IngestConfig +from ..graph.meta import ChunkedGraphMeta +from ..graph.meta import DataSource +from ..graph.meta import GraphConfig + +from ..graph.client import BackendClientInfo +from ..graph.client.bigtable import BigTableConfig + +chunk_id_str = lambda layer, coords: f"{layer}_{'_'.join(map(str, coords))}" + + +def bootstrap( + graph_id: str, + config: dict, + overwrite: bool = False, + raw: bool = False, + test_run: bool = False, +) -> Tuple[ChunkedGraphMeta, IngestConfig, BackendClientInfo]: + """Parse config loaded from a yaml file.""" + ingest_config = IngestConfig( + **config.get("ingest_config", {}), + CLUSTER=ClusterIngestConfig(), + USE_RAW_EDGES=raw, + USE_RAW_COMPONENTS=raw, + TEST_RUN=test_run, + ) + client_config = BigTableConfig(**config["backend_client"]["CONFIG"]) + client_info = BackendClientInfo(config["backend_client"]["TYPE"], client_config) + + graph_config = GraphConfig( + ID=f"{graph_id}", + OVERWRITE=overwrite, + **config["graph_config"], + ) + data_source = DataSource(**config["data_source"]) + + meta = ChunkedGraphMeta(graph_config, data_source) + return (meta, ingest_config, client_info) + + +def postprocess_edge_data(im, edge_dict): + data_version = im.cg_meta.data_source.DATA_VERSION + if data_version == 2: + return edge_dict + elif data_version in [3, 4]: + new_edge_dict = {} + for k in edge_dict: + new_edge_dict[k] = {} + if edge_dict[k] is None or len(edge_dict[k]) == 0: + continue + + areas = ( + edge_dict[k]["area_x"] * im.cg_meta.resolution[0] + + edge_dict[k]["area_y"] * im.cg_meta.resolution[1] + + edge_dict[k]["area_z"] * im.cg_meta.resolution[2] + ) + + affs = ( + edge_dict[k]["aff_x"] * im.cg_meta.resolution[0] + + edge_dict[k]["aff_y"] * im.cg_meta.resolution[1] + + edge_dict[k]["aff_z"] * im.cg_meta.resolution[2] + ) + + new_edge_dict[k]["sv1"] = edge_dict[k]["sv1"] + new_edge_dict[k]["sv2"] = edge_dict[k]["sv2"] + new_edge_dict[k]["area"] = areas + new_edge_dict[k]["aff"] = affs + + return new_edge_dict + else: + raise Exception(f"Unknown data_version: {data_version}") diff --git a/pychunkedgraph/io/__init__.py b/pychunkedgraph/io/__init__.py new file mode 100644 index 000000000..60fe4ebb1 --- /dev/null +++ b/pychunkedgraph/io/__init__.py @@ -0,0 +1,3 @@ +""" +All secondary (slow) storage stuff must go here +""" \ No newline at end of file diff --git a/pychunkedgraph/io/components.py b/pychunkedgraph/io/components.py new file mode 100644 index 000000000..a6301c7d2 --- /dev/null +++ b/pychunkedgraph/io/components.py @@ -0,0 +1,59 @@ +from typing import Dict, Iterable + +import numpy as np +from cloudfiles import CloudFiles + +from .protobuf.chunkComponents_pb2 import ChunkComponentsMsg +from ..graph.utils import basetypes + + +def serialize(connected_components: Iterable) -> ChunkComponentsMsg: + components = [] + for component in list(connected_components): + component = np.array(list(component), dtype=basetypes.NODE_ID) + components.append(np.array([len(component)], dtype=basetypes.NODE_ID)) + components.append(component) + components_message = ChunkComponentsMsg() + components_message.components[:] = np.concatenate(components) + return components_message + + +def deserialize(components_message: ChunkComponentsMsg) -> Dict: + mapping = {} + components = np.array(components_message.components, basetypes.NODE_ID) + idx = 0 + n_components = 0 + while idx < components.size: + component_size = int(components[idx]) + 
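# Worked example of the flat layout parsed here (illustrative values): the components
# [{10, 11}, {12, 13, 14}] serialize to [2, 10, 11, 3, 12, 13, 14], i.e. each component is
# preceded by its length, and deserialize() rebuilds the mapping
# {10: 0, 11: 0, 12: 1, 13: 1, 14: 1}.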
start = idx + 1 + component = components[start : start + component_size] + mapping.update(dict(zip(component, [n_components] * component_size))) + idx += component_size + 1 + n_components += 1 + return mapping + + +def put_chunk_components(components_dir, components, chunk_coord) -> None: + # filename format - components_x_y_z.serliazation + components_message = serialize(components) + filename = f"components_{'_'.join(str(coord) for coord in chunk_coord)}.proto" + cf = CloudFiles(components_dir) + cf.put( + filename, + content=components_message.SerializeToString(), + compress=None, + cache_control="no-cache", + ) + + +def get_chunk_components(components_dir, chunk_coord) -> Dict: + # filename format - components_x_y_z.serliazation + filename = f"components_{'_'.join(str(coord) for coord in chunk_coord)}.proto" + + cf = CloudFiles(components_dir) + content = cf.get(filename) + if not content: + return {} + components_message = ChunkComponentsMsg() + components_message.ParseFromString(content) + return deserialize(components_message) diff --git a/pychunkedgraph/io/edges.py b/pychunkedgraph/io/edges.py new file mode 100644 index 000000000..82595e139 --- /dev/null +++ b/pychunkedgraph/io/edges.py @@ -0,0 +1,105 @@ +# pylint: disable=invalid-name, missing-docstring +""" +Functions for reading and writing edges from cloud storage. +""" +import os +from typing import Dict +from typing import List +from typing import Tuple + +import numpy as np +import zstandard as zstd +from cloudfiles import CloudFiles + +from .protobuf.chunkEdges_pb2 import EdgesMsg +from .protobuf.chunkEdges_pb2 import ChunkEdgesMsg +from ..graph.edges import Edges +from ..graph.edges import EDGE_TYPES +from ..graph.utils import basetypes +from ..graph.edges.utils import concatenate_chunk_edges + + +def serialize(edges: Edges) -> EdgesMsg: + edges_proto = EdgesMsg() + edges_proto.node_ids1 = edges.node_ids1.astype(basetypes.NODE_ID).tobytes() + edges_proto.node_ids2 = edges.node_ids2.astype(basetypes.NODE_ID).tobytes() + edges_proto.affinities = edges.affinities.astype(basetypes.EDGE_AFFINITY).tobytes() + edges_proto.areas = edges.areas.astype(basetypes.EDGE_AREA).tobytes() + return edges_proto + + +def deserialize(edges_message: EdgesMsg) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + sv_ids1 = np.frombuffer(edges_message.node_ids1, basetypes.NODE_ID) + sv_ids2 = np.frombuffer(edges_message.node_ids2, basetypes.NODE_ID) + affinities = np.frombuffer(edges_message.affinities, basetypes.EDGE_AFFINITY) + areas = np.frombuffer(edges_message.areas, basetypes.EDGE_AREA) + return Edges(sv_ids1, sv_ids2, affinities=affinities, areas=areas) + + +def _parse_edges(compressed: List[bytes]) -> List[Dict]: + result = [] + if(len(compressed) == 0): + return result + zdc = zstd.ZstdDecompressor() + try: + n_threads = int(os.environ.get("ZSTD_THREADS", 1)) + except ValueError: + n_threads = 1 + + decompressed = [] + try: + decompressed = zdc.multi_decompress_to_buffer(compressed, threads=n_threads) + except ValueError: + for content in compressed: + decompressed.append(zdc.decompressobj().decompress(content)) + + for content in decompressed: + chunk_edges = ChunkEdgesMsg() + chunk_edges.ParseFromString(memoryview(content)) + edges_dict = {} + edges_dict[EDGE_TYPES.in_chunk] = deserialize(chunk_edges.in_chunk) + edges_dict[EDGE_TYPES.between_chunk] = deserialize(chunk_edges.between_chunk) + edges_dict[EDGE_TYPES.cross_chunk] = deserialize(chunk_edges.cross_chunk) + result.append(edges_dict) + return result + + +def get_chunk_edges(edges_dir: 
str, chunks_coordinates: List[np.ndarray]) -> Dict: + """Read edges from GCS.""" + fnames = [] + for chunk_coords in chunks_coordinates: + chunk_str = "_".join(str(coord) for coord in chunk_coords) + # filename format - edges_x_y_z.serialization.compression + fnames.append(f"edges_{chunk_str}.proto.zst") + + cf = CloudFiles(edges_dir, num_threads=4) + files = cf.get(fnames, raw=True) + compressed = [] + for f in files: + if not f["content"]: + continue + compressed.append(f["content"]) + return concatenate_chunk_edges(_parse_edges(compressed)) + + +def put_chunk_edges( + edges_dir: str, chunk_coordinates: np.ndarray, edges_d, compression_level: int +) -> None: + """Write edges to GCS.""" + chunk_edges = ChunkEdgesMsg() + chunk_edges.in_chunk.CopyFrom(serialize(edges_d[EDGE_TYPES.in_chunk])) + chunk_edges.between_chunk.CopyFrom(serialize(edges_d[EDGE_TYPES.between_chunk])) + chunk_edges.cross_chunk.CopyFrom(serialize(edges_d[EDGE_TYPES.cross_chunk])) + + cctx = zstd.ZstdCompressor(level=compression_level) + chunk_str = "_".join(str(coord) for coord in chunk_coordinates) + + # filename format - edges_x_y_z.serialization.compression + filename = f"edges_{chunk_str}.proto.zst" + cf = CloudFiles(edges_dir) + cf.put( + filename, + content=cctx.compress(chunk_edges.SerializeToString()), + compress=None, + cache_control="no-cache", + ) diff --git a/pychunkedgraph/rechunking/__init__.py b/pychunkedgraph/io/protobuf/__init__.py similarity index 100% rename from pychunkedgraph/rechunking/__init__.py rename to pychunkedgraph/io/protobuf/__init__.py diff --git a/pychunkedgraph/io/protobuf/chunkComponents.proto b/pychunkedgraph/io/protobuf/chunkComponents.proto new file mode 100644 index 000000000..f415c3970 --- /dev/null +++ b/pychunkedgraph/io/protobuf/chunkComponents.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package components; + +message ChunkComponentsMsg { + repeated uint64 components = 1; +} \ No newline at end of file diff --git a/pychunkedgraph/io/protobuf/chunkComponents_pb2.py b/pychunkedgraph/io/protobuf/chunkComponents_pb2.py new file mode 100644 index 000000000..19a3914e8 --- /dev/null +++ b/pychunkedgraph/io/protobuf/chunkComponents_pb2.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: chunkComponents.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x63hunkComponents.proto\x12\ncomponents\"(\n\x12\x43hunkComponentsMsg\x12\x12\n\ncomponents\x18\x01 \x03(\x04\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chunkComponents_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _CHUNKCOMPONENTSMSG._serialized_start=37 + _CHUNKCOMPONENTSMSG._serialized_end=77 +# @@protoc_insertion_point(module_scope) diff --git a/pychunkedgraph/io/protobuf/chunkEdges.proto b/pychunkedgraph/io/protobuf/chunkEdges.proto new file mode 100644 index 000000000..377e182c0 --- /dev/null +++ b/pychunkedgraph/io/protobuf/chunkEdges.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package edges; + +message EdgesMsg { + bytes node_ids1 = 1; + bytes node_ids2 = 2; + bytes affinities = 3; + bytes areas = 4; +} + +message ChunkEdgesMsg { + EdgesMsg in_chunk = 1; + EdgesMsg cross_chunk = 2; + EdgesMsg between_chunk = 3; +} \ No newline at end of file diff --git a/pychunkedgraph/io/protobuf/chunkEdges_pb2.py b/pychunkedgraph/io/protobuf/chunkEdges_pb2.py new file mode 100644 index 000000000..b90c15c39 --- /dev/null +++ b/pychunkedgraph/io/protobuf/chunkEdges_pb2.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: chunkEdges.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63hunkEdges.proto\x12\x05\x65\x64ges\"S\n\x08\x45\x64gesMsg\x12\x11\n\tnode_ids1\x18\x01 \x01(\x0c\x12\x11\n\tnode_ids2\x18\x02 \x01(\x0c\x12\x12\n\naffinities\x18\x03 \x01(\x0c\x12\r\n\x05\x61reas\x18\x04 \x01(\x0c\"\x80\x01\n\rChunkEdgesMsg\x12!\n\x08in_chunk\x18\x01 \x01(\x0b\x32\x0f.edges.EdgesMsg\x12$\n\x0b\x63ross_chunk\x18\x02 \x01(\x0b\x32\x0f.edges.EdgesMsg\x12&\n\rbetween_chunk\x18\x03 \x01(\x0b\x32\x0f.edges.EdgesMsgb\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'chunkEdges_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _EDGESMSG._serialized_start=27 + _EDGESMSG._serialized_end=110 + _CHUNKEDGESMSG._serialized_start=113 + _CHUNKEDGESMSG._serialized_end=241 +# @@protoc_insertion_point(module_scope) diff --git a/pychunkedgraph/jobs/__init__.py b/pychunkedgraph/jobs/__init__.py new file mode 100644 index 000000000..8eb11b638 --- /dev/null +++ b/pychunkedgraph/jobs/__init__.py @@ -0,0 +1,3 @@ +""" +Functions meant to be run as jobs/cron jobs. 
+""" \ No newline at end of file diff --git a/pychunkedgraph/jobs/alert.py b/pychunkedgraph/jobs/alert.py new file mode 100644 index 000000000..84fe3be17 --- /dev/null +++ b/pychunkedgraph/jobs/alert.py @@ -0,0 +1,19 @@ +from typing import Iterable + + +def send_email(to: Iterable[str], subject: str, message: str) -> None: + """ + Uses GMail SMTP server to send alerts. + """ + from os import environ + from smtplib import SMTP_SSL + + email = environ["ALERT_BOT_EMAIL_ID"] + password = environ["ALERT_BOT_EMAIL_PASSWORD"] + text = f"From: AlertBot <{email}>" f"\nSubject: {subject}\n\n{message}" + + server = SMTP_SSL("smtp.gmail.com", 465) + server.ehlo() + server.login(email, password) + server.sendmail(password, to, text) + server.close() diff --git a/pychunkedgraph/jobs/export/__init__.py b/pychunkedgraph/jobs/export/__init__.py new file mode 100644 index 000000000..95a42a5b0 --- /dev/null +++ b/pychunkedgraph/jobs/export/__init__.py @@ -0,0 +1,3 @@ +""" +Export data out of chunkedgraph. +""" \ No newline at end of file diff --git a/pychunkedgraph/jobs/export/datastore/__init__.py b/pychunkedgraph/jobs/export/datastore/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pychunkedgraph/jobs/export/datastore/operation_logs.py b/pychunkedgraph/jobs/export/datastore/operation_logs.py new file mode 100644 index 000000000..b41d574fa --- /dev/null +++ b/pychunkedgraph/jobs/export/datastore/operation_logs.py @@ -0,0 +1,84 @@ +""" +Export operation log data to Datastore. + +Accepts a comma separated list of graph IDs +as command line argument or env variable `GRAPH_IDS` + +Uses `GOOGLE_APPLICATION_CREDENTIALS` service account +if `OPERATION_LOGS_DATASTORE_CREDENTIALS` not provided. +""" +from typing import List +from typing import Optional +from os import environ + +from ....graph import ChunkedGraph + + +def _get_chunkedgraphs(graph_ids: Optional[str] = None) -> List[ChunkedGraph]: + """ + `graph_ids` comma separated list of graph IDs + If None, load from env variable `GRAPH_IDS` + """ + if not graph_ids: + try: + graph_ids = environ["GRAPH_IDS"] + except KeyError: + raise KeyError("Environment variable `GRAPH_IDS` is required.") + + chunkedgraphs = [] + for id_ in [id_.strip() for id_ in graph_ids.split(",")]: + chunkedgraphs.append(ChunkedGraph(graph_id=id_)) + return chunkedgraphs + + +def run_export(chunkedgraphs: List[ChunkedGraph], datastore_ns: str = None) -> None: + """ + Export job for processing edit logs and storing them in Datastore. + Default namespace: pychunkedgraph.export.to.datastore.config.DEFAULT_NS + + Sends an email alert when there are failed writes. + These cases must be inspected manually and are expected to be rare. 
+ """ + from ...alert import send_email + from ....export.to.datastore import export_operation_logs + + if not datastore_ns: + datastore_ns = environ.get("DATASTORE_NS") + + for cg in chunkedgraphs: + print(f"Start log export job for {cg.graph_id}") + failed = export_operation_logs(cg, namespace=datastore_ns) + if failed: + print(f"Failed writes {failed}") + alert_emails = environ["EMAIL_LIST_FAILED_WRITES"] + send_email( + [e.strip() for e in alert_emails.split(",")], + "Failed Writes", + f"TEST: There are {failed} failed writes for {cg.graph_id}.", + ) + + +if __name__ == "__main__": + from sys import argv + from ....export.to.datastore.config import DEFAULT_NS + + assert len(argv) <= 3 + + graph_ids = None + try: + graph_ids = argv[1] + except IndexError: + print("`graph_ids` not provided, using env variable `GRAPH_IDS`") + + datastore_ns = None + try: + datastore_ns = argv[2] + except IndexError: + print("`datastore_namespace` not provided, using env variable `DATASTORE_NS`.") + print( + f"Default `pychunkedgraph.export.to.datastore.config.DEFAULT_NS`\n" + f"{DEFAULT_NS} will be used\n" + f"if env variable `DATASTORE_NS` is not provided." + ) + + run_export(_get_chunkedgraphs(graph_ids=graph_ids), datastore_ns=datastore_ns) diff --git a/pychunkedgraph/jobs/repair/__init__.py b/pychunkedgraph/jobs/repair/__init__.py new file mode 100644 index 000000000..4322f75fd --- /dev/null +++ b/pychunkedgraph/jobs/repair/__init__.py @@ -0,0 +1,8 @@ +""" +Repair failed edit operations. +There is small possibility of an edit operation failing +in the middle of persisting changes, this can corrupt the +chunkedgraph. When this happens a root is locked indefinitely +to prevent any more edits on it until the issue is fixed. +This is meant to be run manually after inspecting such cases. +""" diff --git a/pychunkedgraph/jobs/repair/main.py b/pychunkedgraph/jobs/repair/main.py new file mode 100644 index 000000000..570d59eff --- /dev/null +++ b/pychunkedgraph/jobs/repair/main.py @@ -0,0 +1,69 @@ +""" +Re run failed edit operations. +These jobs get data (failed operations) from Google Datastore. 
+""" +from ...graph import ChunkedGraph +from ...export.models import OperationLog + + +def _repair_operation(cg: ChunkedGraph, log: OperationLog): + from datetime import timedelta + from ...graph.operation import GraphEditOperation + + operation = GraphEditOperation.from_operation_id( + cg, log.id, multicut_as_split=False, privileged_mode=True + ) + ts = log["timestamp"] + result = operation.execute( + operation_id=log.id, + parent_ts=ts - timedelta(seconds=0.1), + override_ts=ts + timedelta(microseconds=(ts.microsecond % 1000) + 10), + ) + + old_roots = operation._update_root_ids() + print("roots", old_roots, result.new_root_ids) + + for root_ in old_roots: + cg.client.unlock_indefinitely_locked_root(root_, result.operation_id) + + +def _repair_failed_operations(graph_id: str = None, datastore_ns: str = None): + from os import environ + from datetime import datetime + from datetime import timedelta + from google.cloud import datastore + from ...graph.operation import GraphEditOperation + + # if not graph_id: + # graph_id = environ["GRAPH_IDS"] + # if not datastore_ns: + # datastore_ns = environ.get("DATASTORE_NS") + + graph_id = "minnie3_v1" + datastore_ns = "pcg_test" + + try: + client = datastore.Client().from_service_account_json( + environ["OPERATION_LOGS_DATASTORE_CREDENTIALS"] + ) + except KeyError: + print("Datastore credentials not provided.") + print(f"Using {environ['GOOGLE_APPLICATION_CREDENTIALS']}") + # use GOOGLE_APPLICATION_CREDENTIALS + # this is usually "/root/.cloudvolume/secrets/.json" + client = datastore.Client() + + query = client.query(kind=f"{graph_id}_failed", namespace=datastore_ns) + cg = ChunkedGraph(graph_id=graph_id) + for log in query.fetch(): + print(f"Re-trying operation ID {log.id}") + _repair_operation(cg, log) + client.delete(log.key) + + +def repair_operations(): + _repair_failed_operations() + + +if __name__ == "__main__": + repair_operations() diff --git a/pychunkedgraph/logging/flask_log_db.py b/pychunkedgraph/logging/flask_log_db.py deleted file mode 100644 index a2f946a44..000000000 --- a/pychunkedgraph/logging/flask_log_db.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -import json -from google.cloud import datastore - -HOME = os.path.expanduser('~') - -# Setting environment wide credential path -os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = \ - HOME + "/.cloudvolume/secrets/google-secret.json" - - -class FlaskLogDatabase(object): - def __init__(self, table_id, project_id="neuromancer-seung-import", - client=None, credentials=None): - self._table_id = table_id - if client is not None: - self._client = client - else: - self._client = datastore.Client(project=project_id, - credentials=credentials) - @property - def table_id(self): - return self._table_id - - @property - def client(self): - return self._client - - @property - def namespace(self): - return 'pychunkedgraphserverdb' - - @property - def kind(self): - return "flask_log_%s" % self.table_id - - def add_success_log(self, user_id, user_ip, request_time, response_time, - url, request_type, request_data=None): - self._add_log(log_type="info", user_id=user_id, user_ip=user_ip, - request_time=request_time, response_time=response_time, - url=url, request_data=request_data, - request_type=request_type) - - def add_internal_error_log(self, user_id, user_ip, request_time, - response_time, url, err_msg, request_data=None): - self._add_log(log_type="internal_error", user_id=user_id, - user_ip=user_ip, request_time=request_time, - response_time=response_time, url=url, - request_data=request_data, 
msg=err_msg) - - def add_unhandled_exception_log(self, user_id, user_ip, request_time, - response_time, url, err_msg, - request_data=None): - self._add_log(log_type="unhandled_exception", user_id=user_id, - user_ip=user_ip, request_time=request_time, - response_time=response_time, url=url, - request_data=request_data, msg=err_msg) - - def _add_log(self, log_type, user_id, user_ip, request_time, response_time, - url, request_type=None, request_arg=None, request_data=None, - msg=None): - # Extract relevant information and build entity - - key = self.client.key(self.kind, namespace=self.namespace) - entity = datastore.Entity(key) - - url_split = url.split("/") - - if request_type: - if "?" in url_split[-1]: - request_type = url_split[-1].split("?")[0] - request_opt_arg = url_split[-1].split("?")[1] - else: - request_type = url_split[-1] - request_opt_arg = None - - if len(request_data) == 0: - request_data = None - else: - request_data = json.loads(request_data) - - entity['type'] = log_type - entity['user_id'] = user_id - entity['user_ip'] = user_ip - entity['date'] = request_time - entity['response_time(ms)'] = response_time - entity['request_type'] = request_type - entity['request_arg'] = request_arg - entity['request_data'] = str(request_data) - entity['request_opt_arg'] = request_opt_arg - entity['url'] = url - entity['msg'] = msg - - self.client.put(entity) - - return entity.key.id diff --git a/pychunkedgraph/logging/jsonformatter.py b/pychunkedgraph/logging/jsonformatter.py index 7c417d42b..79c99b343 100644 --- a/pychunkedgraph/logging/jsonformatter.py +++ b/pychunkedgraph/logging/jsonformatter.py @@ -1,14 +1,14 @@ -from pythonjsonlogger import jsonlogger - - -class JsonFormatter(jsonlogger.JsonFormatter): - def add_fields(self, log_record, record, message_dict): - """Remap `log_record`s fields to fluentd-gcp counterparts.""" - super(JsonFormatter, self).add_fields(log_record, record, message_dict) - log_record["time"] = log_record.get("time", log_record["asctime"]) - log_record["severity"] = log_record.get( - "severity", log_record["levelname"]) - log_record["source"] = log_record.get("source", log_record["name"]) - del log_record["asctime"] - del log_record["levelname"] - del log_record["name"] +from pythonjsonlogger import jsonlogger + + +class JsonFormatter(jsonlogger.JsonFormatter): + def add_fields(self, log_record, record, message_dict): + """Remap `log_record`s fields to fluentd-gcp counterparts.""" + super(JsonFormatter, self).add_fields(log_record, record, message_dict) + log_record["time"] = log_record.get("time", log_record["asctime"]) + log_record["severity"] = log_record.get( + "severity", log_record["levelname"]) + log_record["source"] = log_record.get("source", log_record["name"]) + del log_record["asctime"] + del log_record["levelname"] + del log_record["name"] diff --git a/pychunkedgraph/logging/log_db.py b/pychunkedgraph/logging/log_db.py new file mode 100644 index 000000000..89680500a --- /dev/null +++ b/pychunkedgraph/logging/log_db.py @@ -0,0 +1,136 @@ +# pylint: disable=invalid-name, missing-docstring, too-many-arguments + +import os +import threading +import time +import queue +from datetime import datetime + +from google.api_core.exceptions import GoogleAPIError +from datastoreflex import DatastoreFlex + + +ENABLE_LOGS = os.environ.get("PCG_SERVER_ENABLE_LOGS", "") != "" +LOG_DB_CACHE = {} + +EXCLUDE_FROM_INDEX = os.environ.get( + "PCG_SERVER_LOGS_INDEX_EXCLUDE", "args, time_ms, user_id" +) +EXCLUDE_FROM_INDEX = tuple(attr.strip() for attr in 
EXCLUDE_FROM_INDEX.split(",")) + + +class LogDB: + def __init__(self, graph_id: str, client: DatastoreFlex): + self._graph_id = graph_id + self._client = client + self._kind = f"server_logs_{self._graph_id}" + self._q = queue.Queue() + + @property + def graph_id(self): + return self._graph_id + + @property + def client(self): + return self._client + + def log_endpoint( + self, + path, + endpoint, + args, + user_id, + operation_id, + request_ts, + response_time, + ): + item = { + "name": path, + "endpoint": endpoint, + "args": args, + "user_id": str(user_id), + "request_ts": request_ts, + "time_ms": response_time, + } + if operation_id is not None: + item["operation_id"] = int(operation_id) + self._q.put(item) + + def log_code_block(self, name: str, operation_id, timestamp, time_ms, **kwargs): + item = { + "name": name, + "operation_id": int(operation_id), + "request_ts": timestamp, + "time_ms": time_ms, + } + item.update(kwargs) + self._q.put(item) + + def log_entity(self): + while True: + try: + item = self._q.get_nowait() + key = self.client.key(self._kind, namespace=self._client.namespace) + entity = self.client.entity( + key, exclude_from_indexes=EXCLUDE_FROM_INDEX + ) + entity.update(item) + self.client.put(entity) + except queue.Empty: + time.sleep(1) + + +def get_log_db(graph_id: str) -> LogDB: + try: + return LOG_DB_CACHE[graph_id] + except KeyError: + ... + + try: + project = os.environ["PCG_SERVER_LOGS_PROJECT"] + except KeyError as err: + raise GoogleAPIError(f"Datastore project env not set: {err}") from err + + namespace = os.environ.get("PCG_SERVER_LOGS_NS", "pcg_server_logs_test") + client = DatastoreFlex(project=project, namespace=namespace) + + log_db = LogDB(graph_id, client=client) + LOG_DB_CACHE[graph_id] = log_db + # use threads to exclude time reguired to log + threading.Thread(target=log_db.log_entity, daemon=True).start() + return log_db + + +class TimeIt: + names = [] + operation_id = -1 + + def __init__(self, name: str, graph_id: str, operation_id=-1, **kwargs): + self.names.append(name) + self._start = None + self._graph_id = graph_id + self._ts = datetime.utcnow() + self._kwargs = kwargs + if operation_id != -1: + self.operation_id = operation_id + + def __enter__(self): + self._start = time.time() + + def __exit__(self, *args): + if ENABLE_LOGS is False: + return + + time_ms = (time.time() - self._start) * 1000 + try: + log_db = get_log_db(self._graph_id) + log_db.log_code_block( + name=".".join(self.names), + operation_id=self.operation_id, + timestamp=self._ts, + time_ms=time_ms, + **self._kwargs, + ) + except GoogleAPIError: + ... 
+ self.names.pop() diff --git a/pychunkedgraph/logging/performance.py b/pychunkedgraph/logging/performance.py index c2d9e99c9..e6d63eb59 100644 --- a/pychunkedgraph/logging/performance.py +++ b/pychunkedgraph/logging/performance.py @@ -9,8 +9,8 @@ from matplotlib import pyplot as plt -from pychunkedgraph.logging import flask_log_db -from pychunkedgraph.backend import chunkedgraph +from pychunkedgraph.logging import log_db +from pychunkedgraph.graph import chunkedgraph from google.cloud import datastore from google.auth import credentials, default as default_creds @@ -22,7 +22,7 @@ def readout_log_db(table_id, filters, cols, credentials, project_id = default_creds() client = datastore.Client(project=project_id, credentials=credentials) - log_db = flask_log_db.FlaskLogDatabase(table_id, client=client) + log_db = log_db.LogDatabase(table_id, client=client) query = log_db.client.query(kind=log_db.kind, namespace=log_db.namespace) diff --git a/pychunkedgraph/meshing/manifest/__init__.py b/pychunkedgraph/meshing/manifest/__init__.py new file mode 100644 index 000000000..63e1e7f7c --- /dev/null +++ b/pychunkedgraph/meshing/manifest/__init__.py @@ -0,0 +1,22 @@ +# pylint: disable=invalid-name, missing-docstring + +import numpy as np + +from .cache import ManifestCache +from .sharded import get_children_before_start_layer +from .sharded import verified_manifest as verified_manifest_sharded +from .sharded import speculative_manifest as speculative_manifest_sharded + + +def get_highest_child_nodes_with_meshes( + cg, + node_id: np.uint64, + start_layer: int, + bounding_box=None, +): + return verified_manifest_sharded( + cg, + node_id, + start_layer=start_layer, + bounding_box=bounding_box, + ) diff --git a/pychunkedgraph/meshing/manifest/cache.py b/pychunkedgraph/meshing/manifest/cache.py new file mode 100644 index 000000000..f38a830c2 --- /dev/null +++ b/pychunkedgraph/meshing/manifest/cache.py @@ -0,0 +1,163 @@ +# pylint: disable=invalid-name, missing-docstring, import-outside-toplevel, broad-exception-caught + +import os +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import List + +import redis +import numpy as np + +DOES_NOT_EXIST = "X" +INITIAL_PATH_PREFIX = "initial_path_prefix" + +REDIS_HOST = os.environ.get("MANIFEST_CACHE_REDIS_HOST", "localhost") +REDIS_PORT = os.environ.get("MANIFEST_CACHE_REDIS_PORT", "6379") +REDIS_PASSWORD = os.environ.get("MANIFEST_CACHE_REDIS_PASSWORD", "") +REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0" + + +REDIS = redis.Redis.from_url(REDIS_URL, socket_connect_timeout=1) +try: + REDIS.ping() + REDIS = redis.Redis.from_url(REDIS_URL) +except Exception: + REDIS = None + + +class ManifestCache: + def __init__(self, namespace: str, initial: Optional[bool] = True) -> None: + self._initial = initial + self._namespace = namespace + self._initial_path_prefix = None + + @property + def initial(self) -> str: + return self._initial + + @property + def namespace(self) -> str: + return self._namespace + + @property + def initial_path_prefix(self) -> str: + if self._initial_path_prefix is None: + key = f"{self.namespace}:{INITIAL_PATH_PREFIX}" + path_prefix = REDIS.get(key) + self._initial_path_prefix = path_prefix.decode() if path_prefix else "" + return self._initial_path_prefix + + @initial_path_prefix.setter + def initial_path_prefix(self, path_prefix: str): + self._initial_path_prefix = path_prefix + key = f"{self.namespace}:{INITIAL_PATH_PREFIX}" + REDIS.set(key, path_prefix) + + def get_fragments(self, 
node_ids) -> Tuple[Dict, List[np.uint64], List[np.uint64]]: + if self.initial is True: + return self._get_cached_initial_fragments(node_ids) + return self._get_cached_dynamic_fragments(node_ids) + + def set_fragments(self, fragments_d: Dict, not_existing: List[np.uint64] = None): + if not_existing is None: + not_existing = [] + if self.initial is True: + self._set_cached_initial_fragments(fragments_d, not_existing) + else: + self._set_cached_dynamic_fragments(fragments_d, not_existing) + + def clear_fragments(self, node_ids) -> None: + if REDIS is None: + return + + keys = [f"{self.namespace}:{n}" for n in node_ids] + REDIS.delete(*keys) + + def _get_cached_initial_fragments(self, node_ids: List[np.uint64]): + if REDIS is None: + return {}, node_ids, [] + + pipeline = REDIS.pipeline() + for node_id in node_ids: + pipeline.get(f"{self.namespace}:{node_id}") + + result = {} + not_cached = [] + not_existing = [] + fragments = pipeline.execute() + for node_id, fragment in zip(node_ids, fragments): + if fragment is None: + not_cached.append(node_id) + continue + fragment = fragment.decode() + try: + path, offset, size = fragment.split(":") + result[node_id] = [path, int(offset), int(size)] + except ValueError: + not_existing.append(node_id) + return result, not_cached, not_existing + + def _get_cached_dynamic_fragments(self, node_ids: List[np.uint64]): + if REDIS is None: + return {}, node_ids, [] + + pipeline = REDIS.pipeline() + for node_id in node_ids: + pipeline.get(f"{self.namespace}:{node_id}") + + result = {} + not_cached = [] + not_existing = [] + fragments = pipeline.execute() + for node_id, fragment in zip(node_ids, fragments): + if fragment is None: + not_cached.append(node_id) + continue + fragment = fragment.decode() + if fragment == DOES_NOT_EXIST: + not_existing.append(node_id) + else: + result[node_id] = fragment + return result, not_cached, not_existing + + def _set_cached_initial_fragments( + self, fragments_d: Dict, not_existing: List[np.uint64] + ) -> None: + if REDIS is None: + return + + prefix_idx = 0 + if len(fragments_d) > 0: + fragment_info = next(iter(fragments_d.values())) + path, _, _ = fragment_info + idx = path.find("initial/") + prefix_idx = idx + 8 + path_prefix = path[:prefix_idx] + self.initial_path_prefix = path_prefix + + pipeline = REDIS.pipeline() + for node_id, fragment_info in fragments_d.items(): + path, offset, size = fragment_info + key = f"{self.namespace}:{node_id}" + pipeline.set(key, f"{path[prefix_idx:]}:{offset}:{size}") + + for node_id in not_existing: + pipeline.set(f"{self.namespace}:{node_id}", DOES_NOT_EXIST) + + pipeline.execute() + + def _set_cached_dynamic_fragments( + self, fragments_d: Dict, not_existing: List[np.uint64] + ) -> None: + if REDIS is None: + return + + pipeline = REDIS.pipeline() + for node_id, fragment in fragments_d.items(): + pipeline.set(f"{self.namespace}:{node_id}", fragment) + + for node_id in not_existing: + pipeline.set(f"{self.namespace}:{node_id}", DOES_NOT_EXIST) + + pipeline.execute() diff --git a/pychunkedgraph/meshing/manifest/sharded.py b/pychunkedgraph/meshing/manifest/sharded.py new file mode 100644 index 000000000..2576fcb2f --- /dev/null +++ b/pychunkedgraph/meshing/manifest/sharded.py @@ -0,0 +1,109 @@ +# pylint: disable=invalid-name, missing-docstring, import-outside-toplevel + +from time import time + +import numpy as np +from cloudvolume import CloudVolume + +from .utils import get_children_before_start_layer +from ...graph import ChunkedGraph +from ...graph.types import empty_1d +from 
...graph.utils.basetypes import NODE_ID +from ...graph.chunks import utils as chunk_utils + + +def verified_manifest( + cg: ChunkedGraph, + node_id: np.uint64, + start_layer: int, + bounding_box=None, +): + from .utils import get_mesh_paths + + bounding_box = chunk_utils.normalize_bounding_box( + cg.meta, bounding_box, bbox_is_coordinate=True + ) + node_ids = get_children_before_start_layer( + cg, node_id, start_layer, bounding_box=bounding_box + ) + print(f"children before start_layer {len(node_ids)}") + + start = time() + result = get_mesh_paths(cg, node_ids) + node_ids = np.fromiter(result.keys(), dtype=NODE_ID) + + mesh_files = [] + for val in result.values(): + try: + path, offset, size = val + path = path.split("initial/")[-1] + mesh_files.append(f"~{path}:{offset}:{size}") + except ValueError: + mesh_files.append(val) + print(f"shard lookups took {time() - start}") + return node_ids, mesh_files + + +def speculative_manifest( + cg: ChunkedGraph, + node_id: NODE_ID, + start_layer: int, + stop_layer: int = 2, + bounding_box=None, +): + """ + This assumes children IDs have meshes. + Not checking for their existence reduces latency. + """ + from .utils import check_skips + from .utils import segregate_node_ids + from ..meshgen_utils import get_mesh_name + from ..meshgen_utils import get_json_info + + if start_layer is None: + start_layer = cg.meta.custom_data.get("mesh", {}).get("max_layer", 2) + + start = time() + bounding_box = chunk_utils.normalize_bounding_box( + cg.meta, bounding_box, bbox_is_coordinate=True + ) + node_ids = get_children_before_start_layer( + cg, node_id, start_layer=start_layer, bounding_box=bounding_box + ) + print("children_before_start_layer", time() - start) + + start = time() + result = [empty_1d] + node_layers = cg.get_chunk_layers(node_ids) + while np.any(node_layers > stop_layer): + result.append(node_ids[node_layers == stop_layer]) + ids_ = node_ids[node_layers > stop_layer] + ids_, skips = check_skips(cg, ids_) + + result.append(ids_) + node_ids = skips.copy() + node_layers = cg.get_chunk_layers(node_ids) + + result.append(node_ids[node_layers == stop_layer]) + print("chilren IDs", len(result), time() - start) + + readers = CloudVolume( # pylint: disable=no-member + "graphene://https://localhost/segmentation/table/dummy", + mesh_dir=cg.meta.custom_data.get("mesh", {}).get("dir", "graphene_meshes"), + info=get_json_info(cg), + ).mesh.readers + + node_ids = np.concatenate(result) + initial_ids, new_ids = segregate_node_ids(cg, node_ids) + + # get shards for initial IDs + layers = cg.get_chunk_layers(initial_ids) + chunk_ids = cg.get_chunk_ids_from_node_ids(initial_ids) + mesh_shards = [] + for id_, layer, chunk_id in zip(initial_ids, layers, chunk_ids): + fname, minishard = readers[layer].compute_shard_location(id_) + mesh_shards.append(f"~{id_}:{layer}:{chunk_id}:{fname}:{minishard}") + + # get mesh files for new IDs + mesh_files = [f"{get_mesh_name(cg, id_)}" for id_ in new_ids] + return np.concatenate([initial_ids, new_ids]), mesh_shards + mesh_files diff --git a/pychunkedgraph/meshing/manifest/sharded_format.md b/pychunkedgraph/meshing/manifest/sharded_format.md new file mode 100644 index 000000000..828c84cc2 --- /dev/null +++ b/pychunkedgraph/meshing/manifest/sharded_format.md @@ -0,0 +1,29 @@ +## Manifest formats + +### Legacy format +Legacy format has the standard format defined by neuroglancer - `{segment_id}:{lod}:{bounding_box}`. 
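For illustration, a fragment ID in this legacy format can be taken apart with plain string operations; the concrete values in the sketch below are made up, only the `{segment_id}:{lod}:{bounding_box}` shape is taken from the format above.

fragment_id = "173395595644372020:0:0-512_0-512_0-512"  # hypothetical legacy fragment ID
segment_id, lod, bounding_box = fragment_id.split(":", 2)
# segment_id == "173395595644372020", lod == "0", bounding_box == "0-512_0-512_0-512"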
+
+### Sharded Graphene format
+For large chunkedgraph datasets like `minnie65`, sharding is necessary to reduce the number of mesh files and control storage costs.
+
+There can be two types of mesh fragments in a manifest: `initial` and `dynamic`.
+* `initial`
+
+  Initial IDs are segment IDs generated at the time of chunkedgraph creation. Meshes for these are generated in the form of [shards](https://github.com/seung-lab/cloud-volume/wiki/Graphene#meshing). The following formats are used for mesh fragments of initial IDs, depending on whether the existence of these shards is verified when generating the manifest. `~` is used to denote the sharded format.
+
+  With verification: `~{layer}/{shard_file}:{offset}:{size}`; this is unique for a segment ID.
+
+  e.g. for a segment ID `173395595644372020`: `~2/425884686-0.shard:165832:217`
+
+  Without verification: `~{segment_id}:{layer}:{chunk_id}:{fname}:{minishard_number}`; the segment ID is included to ensure the fragment ID is unique.
+
+  e.g. `~173395595644372020:2:173395595644370944:425884686-0.shard:1`
+
+  With verification, the manifest includes a fragment ID for a segment ID only if its mesh fragment exists. Without verification, the fragment ID is included in the manifest and assumed to exist.
+
+* `dynamic`
+
+  For segment IDs generated by edit operations during proofreading, the legacy format is used to name mesh fragments.
+
+
+For all formats, the `prepend_seg_ids=true` query parameter can be used to add `~{segment_id}:` as a prefix to each fragment ID in the manifest. Neuroglancer can use this to map a segment ID to its mesh fragment in cache, which helps avoid redownloading unaffected fragments after an edit operation.
diff --git a/pychunkedgraph/meshing/manifest/utils.py b/pychunkedgraph/meshing/manifest/utils.py
new file mode 100644
index 000000000..67e600653
--- /dev/null
+++ b/pychunkedgraph/meshing/manifest/utils.py
@@ -0,0 +1,238 @@
+# pylint: disable=invalid-name, missing-docstring, import-outside-toplevel
+
+from datetime import datetime
+from typing import List
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+from typing import Sequence
+
+import numpy as np
+from cloudfiles import CloudFiles
+from cloudvolume import CloudVolume
+
+from .cache import ManifestCache
+from ..meshgen_utils import get_mesh_name
+from ..meshgen_utils import get_json_info
+from ...graph import ChunkedGraph
+from ...graph.types import empty_1d
+from ...graph.utils.basetypes import NODE_ID
+from ...graph.utils import generic as misc_utils
+
+
+def _del_none_keys(d: dict):
+    none_keys = []
+    d_new = dict(d)
+    for k, v in d.items():
+        if v:
+            continue
+        none_keys.append(k)
+        del d_new[k]
+    return d_new, none_keys
+
+
+def _get_children(cg, node_ids: Sequence[np.uint64], children_cache: Dict):
+    """
+    Helper function that makes use of the cache.
+    `check_skips` also needs to know about children, so the cache is shared between them.
+ """ + + if len(node_ids) == 0: + return empty_1d.copy() + node_ids = np.array(node_ids, dtype=NODE_ID) + mask = np.in1d(node_ids, np.fromiter(children_cache.keys(), dtype=NODE_ID)) + children_d = cg.get_children(node_ids[~mask]) + children_cache.update(children_d) + + children = [empty_1d] + for id_ in node_ids: + children.append(children_cache[id_]) + return np.concatenate(children) + + +def _get_initial_meshes( + cg, + shard_readers, + node_ids: Sequence[np.uint64], + stop_layer: int = 2, +) -> Dict: + children_cache = {} + result = {} + if len(node_ids) == 0: + return result + + manifest_cache = ManifestCache(cg.graph_id, initial=True) + node_layers = cg.get_chunk_layers(node_ids) + stop_layer_ids = [node_ids[node_layers == stop_layer]] + + while np.any(node_layers > stop_layer): + stop_layer_ids.append(node_ids[node_layers == stop_layer]) + ids_ = node_ids[node_layers > stop_layer] + ids_, skips = check_skips(cg, ids_, children_cache=children_cache) + + _result, _not_cached, _not_existing = manifest_cache.get_fragments(ids_) + result.update(_result) + + result_ = shard_readers.initial_exists(_not_cached, return_byte_range=True) + result_, not_existing_ = _del_none_keys(result_) + + manifest_cache.set_fragments(result_, not_existing_) + result.update(result_) + + node_ids = _get_children(cg, _not_existing + not_existing_, children_cache) + node_ids = np.concatenate([node_ids, skips]) + node_layers = cg.get_chunk_layers(node_ids) + + # remainder ids + stop_layer_ids = np.concatenate( + [*stop_layer_ids, node_ids[node_layers == stop_layer]] + ) + + _result, _not_cached, _ = manifest_cache.get_fragments(stop_layer_ids) + result.update(_result) + + result_ = shard_readers.initial_exists(_not_cached, return_byte_range=True) + result_, not_existing_ = _del_none_keys(result_) + manifest_cache.set_fragments(result_, not_existing_) + + result.update(result_) + return result + + +def _get_dynamic_meshes(cg, node_ids: Sequence[np.uint64]) -> Tuple[Dict, List]: + result = {} + not_existing = [] + if len(node_ids) == 0: + return result, not_existing + + mesh_dir = cg.meta.custom_data.get("mesh", {}).get("dir", "graphene_meshes") + mesh_path = f"{cg.meta.data_source.WATERSHED}/{mesh_dir}/dynamic" + cf = CloudFiles(mesh_path) + manifest_cache = ManifestCache(cg.graph_id, initial=False) + + result_, not_cached, not_existing_ = manifest_cache.get_fragments(node_ids) + not_existing.extend(not_existing_) + result.update(result_) + + filenames = [get_mesh_name(cg, id_) for id_ in not_cached] + existence_dict = cf.exists(filenames) + + result_ = {} + for mesh_key in existence_dict: + node_id = np.uint64(mesh_key.split(":")[0]) + if existence_dict[mesh_key]: + result_[node_id] = mesh_key + continue + not_existing.append(node_id) + + manifest_cache.set_fragments(result_) + result.update(result_) + return result, np.array(not_existing, dtype=NODE_ID) + + +def _get_initial_and_dynamic_meshes( + cg, + shard_readers: Dict, + node_ids: Sequence[np.uint64], +) -> Tuple[Dict, Dict, List]: + if len(node_ids) == 0: + return {}, {}, [] + + node_ids = np.array(node_ids, dtype=NODE_ID) + initial_ids, new_ids = segregate_node_ids(cg, node_ids) + print("new_ids, initial_ids", new_ids.size, initial_ids.size) + + initial_meshes_d = _get_initial_meshes(cg, shard_readers, initial_ids) + new_meshes_d, missing_ids = _get_dynamic_meshes(cg, new_ids) + return initial_meshes_d, new_meshes_d, missing_ids + + +def check_skips( + cg, node_ids: Sequence[np.uint64], children_cache: Optional[Dict] = None +): + """ + If a node ID has a 
single child, it is considered a skip. + Such IDs won't have meshes because the child mesh will be identical. + """ + + if children_cache is None: + children_cache = {} + + layers = cg.get_chunk_layers(node_ids) + skips = [] + result = [empty_1d, node_ids[layers == 2]] + children_d = cg.get_children(node_ids[layers > 2]) + for p, c in children_d.items(): + if c.size > 1: + result.append([p]) + children_cache[p] = c + continue + assert c.size == 1, f"{p} does not seem to have children." + skips.append(c[0]) + + return np.concatenate(result), np.array(skips, dtype=np.uint64) + + +def segregate_node_ids(cg, node_ids): + """ + Group node IDs based on timestamp + initial = created at the time of ingest + new = created by proofreading edit operations + """ + + initial_ts = cg.meta.custom_data["mesh"]["initial_ts"] + initial_mesh_dt = np.datetime64(datetime.fromtimestamp(initial_ts)) + node_ids_ts = cg.get_node_timestamps(node_ids) + initial_mesh_mask = node_ids_ts < initial_mesh_dt + initial_ids = node_ids[initial_mesh_mask] + new_ids = node_ids[~initial_mesh_mask] + return initial_ids, new_ids + + +def get_mesh_paths( + cg, + node_ids: Sequence[np.uint64], + stop_layer: int = 2, +) -> Dict: + shard_readers = CloudVolume( # pylint: disable=no-member + "graphene://https://localhost/segmentation/table/dummy", + mesh_dir=cg.meta.custom_data.get("mesh", {}).get("dir", "graphene_meshes"), + info=get_json_info(cg), + ).mesh + + result = {} + node_layers = cg.get_chunk_layers(node_ids) + while np.any(node_layers > stop_layer): + node_ids = node_ids[node_layers > 1] + resp = _get_initial_and_dynamic_meshes(cg, shard_readers, node_ids) + initial_meshes_d, new_meshes_d, missing_ids = resp + result.update(initial_meshes_d) + result.update(new_meshes_d) + node_ids = cg.get_children(missing_ids, flatten=True) + node_layers = cg.get_chunk_layers(node_ids) + + # check for left over level 2 IDs + node_ids = node_ids[node_layers > 1] + resp = _get_initial_and_dynamic_meshes(cg, shard_readers, node_ids) + initial_meshes_d, new_meshes_d, _ = resp + result.update(initial_meshes_d) + result.update(new_meshes_d) + return result + + +def get_children_before_start_layer( + cg: ChunkedGraph, node_id: np.uint64, start_layer: int, bounding_box=None +): + if cg.get_chunk_layer(node_id) == 2: + return np.array([node_id], dtype=NODE_ID) + result = [empty_1d] + parents = np.array([node_id], dtype=np.uint64) + while parents.size: + children = cg.get_children(parents, flatten=True) + bound_mask = misc_utils.mask_nodes_by_bounding_box( + cg.meta, children, bounding_box=bounding_box + ) + layers = cg.get_chunk_layers(children) + result.append(children[(layers <= start_layer) & bound_mask]) + parents = children[(layers > start_layer) & bound_mask] + return np.concatenate(result) diff --git a/pychunkedgraph/meshing/mesh_analysis.py b/pychunkedgraph/meshing/mesh_analysis.py new file mode 100644 index 000000000..97bb28f5b --- /dev/null +++ b/pychunkedgraph/meshing/mesh_analysis.py @@ -0,0 +1,101 @@ +from cloudvolume import CloudVolume +import numpy as np +import os +from pychunkedgraph.meshing import meshgen, meshgen_utils + + +def compute_centroid_by_range(vertices): + bbox_min = np.amin(vertices, axis=0) + bbox_max = np.amax(vertices, axis=0) + return bbox_min + ((bbox_max - bbox_min) / 2) + + +def compute_centroid_with_chunk_boundary(cg, vertices, l2_id, last_l2_id): + """ + Given a level 2 id, the vertices of its mesh, and the level 2 id preceding it in + a path, return the center point of the mesh on the chunk boundary 
separating the two + ids, and the center point of the entire mesh. + :param cg: ChunkedGraph object + :param vertices: [[np.float]] + :param l2_id: np.uint64 + :param last_l2_id: np.uint64 or None + :return: [np.float] + """ + centroid_by_range = compute_centroid_by_range(vertices) + if last_l2_id is None: + return [centroid_by_range] + l2_id_cc = cg.get_chunk_coordinates(l2_id) + last_l2_id_cc = cg.get_chunk_coordinates(last_l2_id) + + # Given the coordinates of the two level 2 ids, find the chunk boundary + axis_change = 2 + look_for_max = True + if l2_id_cc[0] != last_l2_id_cc[0]: + axis_change = 0 + elif l2_id_cc[1] != last_l2_id_cc[1]: + axis_change = 1 + if np.sum(l2_id_cc - last_l2_id_cc) > 0: + look_for_max = False + if look_for_max: + value_to_filter = np.amax(vertices[:, axis_change]) + else: + value_to_filter = np.amin(vertices[:, axis_change]) + chunk_boundary_vertices = vertices[ + np.where(vertices[:, axis_change] == value_to_filter) + ] + + # Get the center point of the mesh on the chunk boundary + bbox_min = np.amin(chunk_boundary_vertices, axis=0) + bbox_max = np.amax(chunk_boundary_vertices, axis=0) + return [bbox_min + ((bbox_max - bbox_min) / 2), centroid_by_range] + + +def compute_mesh_centroids_of_l2_ids(cg, l2_ids, flatten=False): + """ + Given a list of l2_ids, return a tuple containing a dict that maps l2_ids to their + mesh's centroid (a global coordinate), and a list of the l2_ids for which the mesh does not exist. + :param cg: ChunkedGraph object + :param l2_ids: Sequence[np.uint64] + :return: Union[Dict[np.uint64, np.ndarray], [np.uint64], [np.uint64]] + """ + cv_sharded_mesh_dir = cg.meta.dataset_info["mesh"] + cv_unsharded_mesh_dir = cg.meta.dataset_info["mesh_metadata"][ + "unsharded_mesh_dir" + ] + cv_unsharded_mesh_path = os.path.join( + cg.meta.data_source.WATERSHED, + cv_sharded_mesh_dir, + cv_unsharded_mesh_dir, + ) + cv = CloudVolume( + f"graphene://https://localhost/segmentation/table/dummy", + mesh_dir=cv_sharded_mesh_dir, + info=meshgen_utils.get_json_info(cg), + ) + meshes = cv.mesh.get_meshes_on_bypass(l2_ids, allow_missing=True) + if flatten: + centroids_with_chunk_boundary_points = [] + else: + centroids_with_chunk_boundary_points = {} + last_l2_id = None + missing_l2_ids = [] + for l2_id_i in l2_ids: + l2_id = int(l2_id_i) + try: + l2_mesh_vertices = meshes[l2_id].vertices + if flatten: + centroids_with_chunk_boundary_points.extend( + compute_centroid_with_chunk_boundary( + cg, l2_mesh_vertices, l2_id, last_l2_id + ) + ) + else: + centroids_with_chunk_boundary_points[ + l2_id + ] = compute_centroid_with_chunk_boundary( + cg, l2_mesh_vertices, l2_id, last_l2_id + ) + except: + missing_l2_ids.append(l2_id) + last_l2_id = l2_id + return centroids_with_chunk_boundary_points, missing_l2_ids \ No newline at end of file diff --git a/pychunkedgraph/meshing/mesh_io.py b/pychunkedgraph/meshing/mesh_io.py index 282ff0b97..40c02bba0 100644 --- a/pychunkedgraph/meshing/mesh_io.py +++ b/pychunkedgraph/meshing/mesh_io.py @@ -167,7 +167,7 @@ def load_obj(self): norms.append(0) faces.append(face) - self._faces = np.array(faces, dtype=np.int) - 1 + self._faces = np.array(faces, dtype=int) - 1 self._vertices = np.array(vertices, dtype=np.float) self._normals = np.array(normals, dtype=np.float) diff --git a/pychunkedgraph/meshing/mesh_worker.py b/pychunkedgraph/meshing/mesh_worker.py new file mode 100644 index 000000000..6d786fc42 --- /dev/null +++ b/pychunkedgraph/meshing/mesh_worker.py @@ -0,0 +1,20 @@ +import argparse +from taskqueue import TaskQueue +import os 
+import pychunkedgraph.meshing.meshing_sqs + +if __name__ == '__main__': + ppid = os.getppid() # Save parent process ID + + def stop_fn_with_parent_health_check(): + if os.getppid() != ppid: + print("Parent process is gone. {} shutting down...".format(os.getpid())) + return True + return False + + parser = argparse.ArgumentParser() + parser.add_argument('--qurl', type=str, required=True) + parser.add_argument('--lease_seconds', type=int, default=100, help='no. of seconds that polling will lease a task before it becomes visible again') + args = parser.parse_args() + with TaskQueue(args.qurl, n_threads=0) as tq: + tq.poll(stop_fn=stop_fn_with_parent_health_check, lease_seconds=args.lease_seconds) \ No newline at end of file diff --git a/pychunkedgraph/meshing/meshengine.py b/pychunkedgraph/meshing/meshengine.py index fd2a461c3..615e6cdb6 100644 --- a/pychunkedgraph/meshing/meshengine.py +++ b/pychunkedgraph/meshing/meshengine.py @@ -3,7 +3,7 @@ import itertools import random -from pychunkedgraph.backend import chunkedgraph +from pychunkedgraph.graph import chunkedgraph from multiwrapper import multiprocessing_utils as mu from . import meshgen @@ -85,7 +85,7 @@ def mesh_multiple_layers(self, layers=None, bounding_box=None, if layers is None: layers = range(1, int(self.cg.n_layers + 1)) - layers = np.array(layers, dtype=np.int) + layers = np.array(layers, dtype=int) layers = layers[layers > 0] layers = layers[layers < self.highest_mesh_layer + 1] @@ -106,16 +106,16 @@ def mesh_single_layer(self, layer, bounding_box=None, block_factor=2, block_bounding_box_cg = \ [np.floor(dataset_bounding_box[:3] / - self.cg.chunk_size).astype(np.int), + self.cg.chunk_size).astype(int), np.ceil(dataset_bounding_box[3:] / - self.cg.chunk_size).astype(np.int)] + self.cg.chunk_size).astype(int)] if bounding_box is not None: bounding_box_cg = \ [np.floor(bounding_box[0] / - self.cg.chunk_size).astype(np.int), + self.cg.chunk_size).astype(int), np.ceil(bounding_box[1] / - self.cg.chunk_size).astype(np.int)] + self.cg.chunk_size).astype(int)] m = block_bounding_box_cg[0] < bounding_box_cg[0] block_bounding_box_cg[0][m] = bounding_box_cg[0][m] @@ -147,7 +147,7 @@ def mesh_single_layer(self, layer, bounding_box=None, block_factor=2, block_bounding_box_cg[1][2], block_factor)) - blocks = np.array(list(block_iter), dtype=np.int) + blocks = np.array(list(block_iter), dtype=int) cg_info = self.cg.get_serialized_info() del (cg_info['credentials']) @@ -176,11 +176,11 @@ def mesh_single_layer(self, layer, bounding_box=None, block_factor=2, def create_manifests_for_higher_layers(self, n_threads=1): root_id_max = self.cg.get_max_node_id( - self.cg.get_chunk_id(layer=np.int(self.cg.n_layers), - x=np.int(0), y=np.int(0), - z=np.int(0))) + self.cg.get_chunk_id(layer=int(self.cg.n_layers), + x=int(0), y=int(0), + z=int(0))) - root_id_blocks = np.linspace(1, root_id_max, n_threads*3).astype(np.int) + root_id_blocks = np.linspace(1, root_id_max, n_threads*3).astype(int) cg_info = self.cg.get_serialized_info() del (cg_info['credentials']) diff --git a/pychunkedgraph/meshing/meshgen.py b/pychunkedgraph/meshing/meshgen.py index f7e741a90..a8da89b1f 100644 --- a/pychunkedgraph/meshing/meshgen.py +++ b/pychunkedgraph/meshing/meshgen.py @@ -1,312 +1,70 @@ -from pychunkedgraph.utils.general import redis_job +# pylint: disable=invalid-name, missing-docstring, too-many-lines, wrong-import-order, import-outside-toplevel, no-member, c-extension-no-member + from typing import Sequence -import sys import os import numpy as np -import json import 
time import collections from functools import lru_cache import datetime import pytz -import cloudvolume -from scipy import ndimage, sparse -import networkx as nx +from scipy import ndimage from multiwrapper import multiprocessing_utils as mu -from cloudvolume import Storage, EmptyVolumeException -from cloudvolume.lib import Vec +from cloudfiles import CloudFiles +from cloudvolume import CloudVolume +from cloudvolume.datasource.precomputed.sharding import ShardingSpecification import DracoPy import zmesh import fastremap -import time - -sys.path.insert(0, os.path.join(sys.path[0], "../..")) -os.environ["TRAVIS_BRANCH"] = "IDONTKNOWWHYINEEDTHIS" -UTC = pytz.UTC -from pychunkedgraph.backend import chunkedgraph # noqa -from pychunkedgraph.backend.utils import serializers, column_keys # noqa +from pychunkedgraph.graph.chunkedgraph import ChunkedGraph # noqa +from pychunkedgraph.graph import attributes # noqa from pychunkedgraph.meshing import meshgen_utils # noqa +from pychunkedgraph.meshing.manifest.cache import ManifestCache + + +UTC = pytz.UTC # Change below to true if debugging and want to see results in stdout PRINT_FOR_DEBUGGING = False # Change below to false if debugging and do not need to write to cloud (warning: do not deploy w/ below set to false) WRITING_TO_CLOUD = True +REDIS_HOST = os.environ.get("REDIS_SERVICE_HOST", "localhost") +REDIS_PORT = os.environ.get("REDIS_SERVICE_PORT", "6379") +REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "dev") +REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0" + def decode_draco_mesh_buffer(fragment): try: mesh_object = DracoPy.decode_buffer_to_mesh(fragment) vertices = np.array(mesh_object.points) faces = np.array(mesh_object.faces) - except ValueError: - raise ValueError("Not a valid draco mesh") + except ValueError as exc: + raise ValueError("Not a valid draco mesh") from exc - assert len(vertices) % 3 == 0, "Draco mesh vertices not 3-D" - num_vertices = len(vertices) // 3 + num_vertices = len(vertices) # For now, just return this dict until we figure out # how exactly to deal with Draco's lossiness/duplicate vertices return { "num_vertices": num_vertices, - "vertices": vertices.reshape(num_vertices, 3), + "vertices": vertices, "faces": faces, "encoding_options": mesh_object.encoding_options, "encoding_type": "draco", } -@lru_cache(maxsize=None) -def get_l2_remapping(cg, chunk_id, time_stamp): - """ Retrieves l2 node id to sv id mappping - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param time_stamp: datetime object - :return: dictionary - """ - rr_chunk = cg.range_read_chunk( - chunk_id=chunk_id, columns=column_keys.Hierarchy.Child, time_stamp=time_stamp - ) - - # This for-loop ensures that only the latest l2_ids are considered - # The order by id guarantees the time order (only true for same neurons - # but that is the case here). 
- l2_remapping = {} - all_sv_ids = set() - for (k, row) in rr_chunk.items(): - this_sv_ids = row[0].value - - if this_sv_ids[0] in all_sv_ids: - continue - - all_sv_ids = all_sv_ids.union(set(list(this_sv_ids))) - l2_remapping[k] = this_sv_ids - - return l2_remapping - - -@lru_cache(maxsize=None) -def get_root_l2_remapping(cg, chunk_id, stop_layer, time_stamp, n_threads=4): - """ Retrieves root to l2 node id mapping - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param stop_layer: int - :param time_stamp: datetime object - :return: multiples - """ - - def _get_root_ids(args): - start_id, end_id = args - - root_ids[start_id:end_id] = cg.get_roots(l2_ids[start_id:end_id]) - - l2_id_remap = get_l2_remapping(cg, chunk_id, time_stamp=time_stamp) - - l2_ids = np.array(list(l2_id_remap.keys())) - - root_ids = np.zeros(len(l2_ids), dtype=np.uint64) - n_jobs = np.min([n_threads, len(l2_ids)]) - multi_args = [] - start_ids = np.linspace(0, len(l2_ids), n_jobs + 1).astype(np.int) - for i_block in range(n_jobs): - multi_args.append([start_ids[i_block], start_ids[i_block + 1]]) - - if n_jobs > 0: - mu.multithread_func(_get_root_ids, multi_args, n_threads=n_threads) - - return l2_ids, root_ids, l2_id_remap - - -# @lru_cache(maxsize=None) -def get_l2_overlapping_remappings(cg, chunk_id, time_stamp=None, n_threads=1): - """ Retrieves sv id to l2 id mapping for chunk with overlap in positive - direction (one chunk) - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param time_stamp: datetime object - :return: multiples - """ - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - chunk_coords = cg.get_chunk_coordinates(chunk_id) - chunk_layer = cg.get_chunk_layer(chunk_id) - - neigh_chunk_ids = [] - neigh_parent_chunk_ids = [] - - # Collect neighboring chunks and their parent chunk ids - # We only need to know about the parent chunk ids to figure the lowest - # common chunk - # Notice that the first neigh_chunk_id is equal to `chunk_id`. - for x in range(chunk_coords[0], chunk_coords[0] + 2): - for y in range(chunk_coords[1], chunk_coords[1] + 2): - for z in range(chunk_coords[2], chunk_coords[2] + 2): - - # Chunk id - neigh_chunk_id = cg.get_chunk_id(x=x, y=y, z=z, layer=chunk_layer) - neigh_chunk_ids.append(neigh_chunk_id) - - # Get parent chunk ids - parent_chunk_ids = cg.get_parent_chunk_ids(neigh_chunk_id) - neigh_parent_chunk_ids.append(parent_chunk_ids) - - # Find lowest common chunk - neigh_parent_chunk_ids = np.array(neigh_parent_chunk_ids) - layer_agreement = np.all( - (neigh_parent_chunk_ids - neigh_parent_chunk_ids[0]) == 0, axis=0 - ) - stop_layer = np.where(layer_agreement)[0][0] + 1 - - # Find the parent in the lowest common chunk for each l2 id. These parent - # ids are referred to as root ids even though they are not necessarily the - # root id. 
- neigh_l2_ids = [] - neigh_l2_id_remap = {} - neigh_root_ids = [] - - safe_l2_ids = [] - unsafe_l2_ids = [] - unsafe_root_ids = [] - - # This loop is the main bottleneck - for neigh_chunk_id in neigh_chunk_ids: - print(neigh_chunk_id, "--------------") - before_time = time.time() - - l2_ids, root_ids, l2_id_remap = get_root_l2_remapping( - cg, neigh_chunk_id, stop_layer, time_stamp=time_stamp, n_threads=n_threads - ) - print("get_root_l2_remapping time", time.time() - before_time) - neigh_l2_ids.extend(l2_ids) - neigh_l2_id_remap.update(l2_id_remap) - neigh_root_ids.extend(root_ids) - - if neigh_chunk_id == chunk_id: - # The first neigh_chunk_id is the one we are interested in. All l2 ids - # that share no root id with any other l2 id are "safe", meaning that - # we can easily obtain the complete remapping (including overlap) for these. - # All other ones have to be resolved using the segmentation. - _, u_idx, c_root_ids = np.unique( - neigh_root_ids, return_counts=True, return_index=True - ) - - safe_l2_ids = l2_ids[u_idx[c_root_ids == 1]] - unsafe_l2_ids = l2_ids[~np.in1d(l2_ids, safe_l2_ids)] - unsafe_root_ids = np.unique(root_ids[u_idx[c_root_ids != 1]]) - - l2_root_dict = dict(zip(neigh_l2_ids, neigh_root_ids)) - root_l2_dict = collections.defaultdict(list) - - # Future sv id -> l2 mapping - sv_ids = [] - l2_ids_flat = [] - - # Do safe ones first - for i_root_id in range(len(neigh_root_ids)): - root_l2_dict[neigh_root_ids[i_root_id]].append(neigh_l2_ids[i_root_id]) - - for l2_id in safe_l2_ids: - root_id = l2_root_dict[l2_id] - for neigh_l2_id in root_l2_dict[root_id]: - l2_sv_ids = neigh_l2_id_remap[neigh_l2_id] - sv_ids.extend(l2_sv_ids) - l2_ids_flat.extend([l2_id] * len(neigh_l2_id_remap[neigh_l2_id])) - - # For the unsafe ones we can only do the in chunk svs - # But we will map the out of chunk svs to the root id and store the - # hierarchical information in a dictionary - for l2_id in unsafe_l2_ids: - sv_ids.extend(neigh_l2_id_remap[l2_id]) - l2_ids_flat.extend([l2_id] * len(neigh_l2_id_remap[l2_id])) - - unsafe_dict = collections.defaultdict(list) - for root_id in unsafe_root_ids: - if np.sum(~np.in1d(root_l2_dict[root_id], unsafe_l2_ids)) == 0: - continue - - for neigh_l2_id in root_l2_dict[root_id]: - unsafe_dict[root_id].append(neigh_l2_id) - - if neigh_l2_id in unsafe_l2_ids: - continue - - sv_ids.extend(neigh_l2_id_remap[neigh_l2_id]) - l2_ids_flat.extend([root_id] * len(neigh_l2_id_remap[neigh_l2_id])) - - # Combine the lists for a (chunk-) global remapping - sv_remapping = dict(zip(sv_ids, l2_ids_flat)) - - return sv_remapping, unsafe_dict - - -def get_remapped_segmentation( - cg, chunk_id, mip=2, overlap_vx=1, time_stamp=None, n_threads=1 -): - """ Downloads + remaps ws segmentation + resolve unclear cases - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param mip: int - :param overlap_vx: int - :param time_stamp: - :return: remapped segmentation - """ - - def _remap(a): - if a in sv_remapping: - return sv_remapping[a] - else: - return 0 - - assert mip >= cg.cv.mip - - sv_remapping, unsafe_dict = get_lx_overlapping_remappings( - cg, chunk_id, time_stamp=time_stamp, n_threads=n_threads - ) - - cv = cloudvolume.CloudVolume(cg.cv.cloudpath, mip=mip) - mip_diff = mip - cg.cv.mip - - mip_chunk_size = cg.chunk_size.astype(np.int) / np.array( - [2 ** mip_diff, 2 ** mip_diff, 1] - ) - mip_chunk_size = mip_chunk_size.astype(np.int) - - chunk_start = ( - cg.cv.mip_voxel_offset(mip) - + cg.get_chunk_coordinates(chunk_id) * mip_chunk_size - ) - chunk_end = 
chunk_start + mip_chunk_size + overlap_vx - chunk_end = Vec.clamp( - chunk_end, - cg.cv.mip_voxel_offset(mip), - cg.cv.mip_voxel_offset(mip) + cg.cv.mip_volume_size(mip), - ) - - ws_seg = cv[ - chunk_start[0] : chunk_end[0], - chunk_start[1] : chunk_end[1], - chunk_start[2] : chunk_end[2], - ].squeeze() - - _remap_vec = np.vectorize(_remap) - seg = _remap_vec(ws_seg).astype(np.uint64) - +def remap_seg_using_unsafe_dict(seg, unsafe_dict): for unsafe_root_id in unsafe_dict.keys(): bin_seg = seg == unsafe_root_id if np.sum(bin_seg) == 0: continue - l2_edges = [] cc_seg, n_cc = ndimage.label(bin_seg) for i_cc in range(1, n_cc + 1): bin_cc_seg = cc_seg == i_cc @@ -321,31 +79,37 @@ def _remap(a): if len(linked_l2_ids) == 0: seg[bin_cc_seg] = 0 - elif len(linked_l2_ids) == 1: - seg[bin_cc_seg] = linked_l2_ids[0] else: seg[bin_cc_seg] = linked_l2_ids[0] - for i_l2_id in range(len(linked_l2_ids) - 1): - for j_l2_id in range(i_l2_id + 1, len(linked_l2_ids)): - l2_edges.append( - [linked_l2_ids[i_l2_id], linked_l2_ids[j_l2_id]] - ) + return seg - if len(l2_edges) > 0: - g = nx.Graph() - g.add_edges_from(l2_edges) - ccs = nx.connected_components(g) +def get_remapped_segmentation( + cg, chunk_id, mip=2, overlap_vx=1, time_stamp=None, n_threads=1 +): + """Downloads + remaps ws segmentation + resolve unclear cases - for cc in ccs: - cc_ids = np.sort(list(cc)) - seg[np.in1d(seg, cc_ids[1:]).reshape(seg.shape)] = cc_ids[0] + :param cg: chunkedgraph object + :param chunk_id: np.uint64 + :param mip: int + :param overlap_vx: int + :param time_stamp: + :return: remapped segmentation + """ + assert mip >= cg.meta.cv.mip - return seg + sv_remapping, unsafe_dict = get_lx_overlapping_remappings( + cg, chunk_id, time_stamp=time_stamp, n_threads=n_threads + ) + + ws_seg = meshgen_utils.get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx) + seg = fastremap.mask_except(ws_seg, list(sv_remapping.keys()), in_place=False) + fastremap.remap(seg, sv_remapping, preserve_missing_labels=True, in_place=True) + + return remap_seg_using_unsafe_dict(seg, unsafe_dict) -# TODO: refactor (duplicated code with get_remapped_segmentation) def get_remapped_seg_for_lvl2_nodes( cg, chunk_id: np.uint64, @@ -355,7 +119,8 @@ def get_remapped_seg_for_lvl2_nodes( time_stamp=None, n_threads: int = 1, ): - """ Downloads + remaps ws segmentation + resolve unclear cases, filter out all but specified lvl2_nodes + """Downloads + remaps ws segmentation + resolve unclear cases, + filter out all but specified lvl2_nodes :param cg: chunkedgraph object :param chunk_id: np.uint64 @@ -364,32 +129,7 @@ def get_remapped_seg_for_lvl2_nodes( :param time_stamp: :return: remapped segmentation """ - # Determine the segmentation bounding box to download given cg, chunk_id, and mip. 
Then download - cv = cloudvolume.CloudVolume(cg.cv.cloudpath, mip=mip) - mip_diff = mip - cg.cv.mip - - mip_chunk_size = cg.chunk_size.astype(np.int) / np.array( - [2 ** mip_diff, 2 ** mip_diff, 1] - ) - mip_chunk_size = mip_chunk_size.astype(np.int) - - chunk_start = ( - cg.cv.mip_voxel_offset(mip) - + cg.get_chunk_coordinates(chunk_id) * mip_chunk_size - ) - chunk_end = chunk_start + mip_chunk_size + overlap_vx - chunk_end = Vec.clamp( - chunk_end, - cg.cv.mip_voxel_offset(mip), - cg.cv.mip_voxel_offset(mip) + cg.cv.mip_volume_size(mip), - ) - - seg = cv[ - chunk_start[0] : chunk_end[0], - chunk_start[1] : chunk_end[1], - chunk_start[2] : chunk_end[2], - ].squeeze() - + seg = meshgen_utils.get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx) sv_of_lvl2_nodes = cg.get_children(lvl2_nodes) # Check which of the lvl2_nodes meet the chunk boundary @@ -399,7 +139,8 @@ def get_remapped_seg_for_lvl2_nodes( node_on_the_border = False for sv_id in sv_list: remapping[sv_id] = node - # If a node_id is on the chunk_boundary, we must check the overlap region to see if the meshes' end will be open or closed + # If a node_id is on the chunk_boundary, we must check + # the overlap region to see if the meshes' end will be open or closed if (not node_on_the_border) and ( np.isin(sv_id, seg[-2, :, :]) or np.isin(sv_id, seg[:, -2, :]) @@ -423,51 +164,13 @@ def get_remapped_seg_for_lvl2_nodes( sv_remapping.update(remapping) fastremap.mask_except(seg, list(sv_remapping.keys()), in_place=True) fastremap.remap(seg, sv_remapping, preserve_missing_labels=True, in_place=True) - # For some supervoxel, they could map to multiple l2 nodes in the chunk, so we must perform a connected component analysis + # For some supervoxel, they could map to multiple l2 nodes in the chunk, + # so we must perform a connected component analysis # to see which l2 node they are adjacent to - for unsafe_root_id in unsafe_dict.keys(): - bin_seg = seg == unsafe_root_id - - if np.sum(bin_seg) == 0: - continue - - l2_edges = [] - cc_seg, n_cc = ndimage.label(bin_seg) - for i_cc in range(1, n_cc + 1): - bin_cc_seg = cc_seg == i_cc - - overlaps = [] - overlaps.extend(np.unique(seg[-2, :, :][bin_cc_seg[-1, :, :]])) - overlaps.extend(np.unique(seg[:, -2, :][bin_cc_seg[:, -1, :]])) - overlaps.extend(np.unique(seg[:, :, -2][bin_cc_seg[:, :, -1]])) - overlaps = np.unique(overlaps) - - linked_l2_ids = overlaps[np.in1d(overlaps, unsafe_dict[unsafe_root_id])] - - if len(linked_l2_ids) == 0: - seg[bin_cc_seg] = 0 - elif len(linked_l2_ids) == 1: - seg[bin_cc_seg] = linked_l2_ids[0] - else: - seg[bin_cc_seg] = linked_l2_ids[0] - - for i_l2_id in range(len(linked_l2_ids) - 1): - for j_l2_id in range(i_l2_id + 1, len(linked_l2_ids)): - l2_edges.append( - [linked_l2_ids[i_l2_id], linked_l2_ids[j_l2_id]] - ) - - if len(l2_edges) > 0: - g = nx.Graph() - g.add_edges_from(l2_edges) - - ccs = nx.connected_components(g) - - for cc in ccs: - cc_ids = np.sort(list(cc)) - seg[np.in1d(seg, cc_ids[1:]).reshape(seg.shape)] = cc_ids[0] + return remap_seg_using_unsafe_dict(seg, unsafe_dict) else: - # If no nodes in our subset meet the chunk boundary we can simply retrieve the sv of the nodes in the subset + # If no nodes in our subset meet the chunk boundary + # we can simply retrieve the sv of the nodes in the subset fastremap.mask_except(seg, list(remapping.keys()), in_place=True) fastremap.remap(seg, remapping, preserve_missing_labels=True, in_place=True) @@ -476,7 +179,7 @@ def get_remapped_seg_for_lvl2_nodes( @lru_cache(maxsize=None) def 
get_higher_to_lower_remapping(cg, chunk_id, time_stamp): - """ Retrieves lx node id to sv id mappping + """Retrieves lx node id to sv id mappping :param cg: chunkedgraph object :param chunk_id: np.uint64 @@ -488,7 +191,7 @@ def _lower_remaps(ks): return np.concatenate([lower_remaps[k] for k in ks]) assert cg.get_chunk_layer(chunk_id) >= 2 - assert cg.get_chunk_layer(chunk_id) <= cg.n_layers + assert cg.get_chunk_layer(chunk_id) <= cg.meta.layer_count print(f"\n{chunk_id} ----------------\n") @@ -501,7 +204,7 @@ def _lower_remaps(ks): ) rr_chunk = cg.range_read_chunk( - chunk_id=chunk_id, columns=column_keys.Hierarchy.Child, time_stamp=time_stamp + chunk_id=chunk_id, properties=attributes.Hierarchy.Child, time_stamp=time_stamp ) # This for-loop ensures that only the latest lx_ids are considered @@ -511,7 +214,6 @@ def _lower_remaps(ks): all_lower_ids = set() for k in sorted(rr_chunk.keys(), reverse=True): this_child_ids = rr_chunk[k][0].value - if this_child_ids[0] in all_lower_ids: continue @@ -532,7 +234,7 @@ def _lower_remaps(ks): @lru_cache(maxsize=None) def get_root_lx_remapping(cg, chunk_id, stop_layer, time_stamp, n_threads=1): - """ Retrieves root to l2 node id mapping + """Retrieves root to l2 node id mapping :param cg: chunkedgraph object :param chunk_id: np.uint64 @@ -544,7 +246,9 @@ def get_root_lx_remapping(cg, chunk_id, stop_layer, time_stamp, n_threads=1): def _get_root_ids(args): start_id, end_id = args root_ids[start_id:end_id] = cg.get_roots( - lx_ids[start_id:end_id], stop_layer=stop_layer + lx_ids[start_id:end_id], + stop_layer=stop_layer, + fail_to_zero=True, ) lx_id_remap = get_higher_to_lower_remapping(cg, chunk_id, time_stamp=time_stamp) @@ -554,7 +258,7 @@ def _get_root_ids(args): root_ids = np.zeros(len(lx_ids), dtype=np.uint64) n_jobs = np.min([n_threads, len(lx_ids)]) multi_args = [] - start_ids = np.linspace(0, len(lx_ids), n_jobs + 1).astype(np.int) + start_ids = np.linspace(0, len(lx_ids), n_jobs + 1).astype(int) for i_block in range(n_jobs): multi_args.append([start_ids[i_block], start_ids[i_block + 1]]) @@ -564,22 +268,7 @@ def _get_root_ids(args): return lx_ids, np.array(root_ids), lx_id_remap -# @lru_cache(maxsize=None) -def get_lx_overlapping_remappings(cg, chunk_id, time_stamp=None, n_threads=1): - """ Retrieves sv id to layer mapping for chunk with overlap in positive - direction (one chunk) - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param time_stamp: datetime object - :return: multiples - """ - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - +def calculate_stop_layer(cg, chunk_id): chunk_coords = cg.get_chunk_coordinates(chunk_id) chunk_layer = cg.get_chunk_layer(chunk_id) @@ -593,23 +282,46 @@ def get_lx_overlapping_remappings(cg, chunk_id, time_stamp=None, n_threads=1): for x in range(chunk_coords[0], chunk_coords[0] + 2): for y in range(chunk_coords[1], chunk_coords[1] + 2): for z in range(chunk_coords[2], chunk_coords[2] + 2): - # Chunk id - neigh_chunk_id = cg.get_chunk_id(x=x, y=y, z=z, layer=chunk_layer) - neigh_chunk_ids.append(neigh_chunk_id) - - # Get parent chunk ids - parent_chunk_ids = cg.get_parent_chunk_ids(neigh_chunk_id) - neigh_parent_chunk_ids.append(parent_chunk_ids) + try: + neigh_chunk_id = cg.get_chunk_id(x=x, y=y, z=z, layer=chunk_layer) + # Get parent chunk ids + parent_chunk_ids = cg.get_parent_chunk_ids(neigh_chunk_id) + neigh_chunk_ids.append(neigh_chunk_id) + neigh_parent_chunk_ids.append(parent_chunk_ids) + 
except: + # cg.get_parent_chunk_id can fail if neigh_chunk_id is outside the dataset + # (only happens when cg.meta.bitmasks[chunk_layer+1] == log(max(x,y,z)), + # so only for specific datasets in which the # of chunks in the widest dimension + # just happens to be a power of two) + pass # Find lowest common chunk neigh_parent_chunk_ids = np.array(neigh_parent_chunk_ids) layer_agreement = np.all( (neigh_parent_chunk_ids - neigh_parent_chunk_ids[0]) == 0, axis=0 ) - stop_layer = np.where(layer_agreement)[0][0] + 1 + chunk_layer - # stop_layer = cg.n_layers + stop_layer = np.where(layer_agreement)[0][0] + chunk_layer + + return stop_layer, neigh_chunk_ids + +# @lru_cache(maxsize=None) +def get_lx_overlapping_remappings(cg, chunk_id, time_stamp=None, n_threads=1): + """Retrieves sv id to layer mapping for chunk with overlap in positive + direction (one chunk) + + :param cg: chunkedgraph object + :param chunk_id: np.uint64 + :param time_stamp: datetime object + :return: multiples + """ + if time_stamp is None: + time_stamp = datetime.datetime.utcnow() + if time_stamp.tzinfo is None: + time_stamp = UTC.localize(time_stamp) + + stop_layer, neigh_chunk_ids = calculate_stop_layer(cg, chunk_id) print(f"Stop layer: {stop_layer}") # Find the parent in the lowest common chunk for each l2 id. These parent @@ -655,10 +367,10 @@ def get_lx_overlapping_remappings(cg, chunk_id, time_stamp=None, n_threads=1): sv_ids = [] lx_ids_flat = [] - # Do safe ones first for i_root_id in range(len(neigh_root_ids)): root_lx_dict[neigh_root_ids[i_root_id]].append(neigh_lx_ids[i_root_id]) + # Do safe ones first for lx_id in safe_lx_ids: root_id = lx_root_dict[lx_id] for neigh_lx_id in root_lx_dict[root_id]: @@ -696,7 +408,7 @@ def get_lx_overlapping_remappings(cg, chunk_id, time_stamp=None, n_threads=1): def get_root_remapping_for_nodes_and_svs( cg, chunk_id, node_ids, sv_ids, stop_layer, time_stamp, n_threads=1 ): - """ Retrieves root to node id mapping for specified node ids and supervoxel ids + """Retrieves root to node id mapping for specified node ids and supervoxel ids :param cg: chunkedgraph object :param chunk_id: np.uint64 @@ -710,19 +422,23 @@ def _get_root_ids(args): start_id, end_id = args root_ids[start_id:end_id] = cg.get_roots( - combined_ids[start_id:end_id], stop_layer=stop_layer, time_stamp=time_stamp + combined_ids[start_id:end_id], + stop_layer=stop_layer, + time_stamp=time_stamp, + fail_to_zero=True, ) rr = cg.range_read_chunk( - chunk_id=chunk_id, columns=column_keys.Hierarchy.Parent, time_stamp=time_stamp + chunk_id=chunk_id, properties=attributes.Hierarchy.Child, time_stamp=time_stamp ) - upper_lvl_ids = [id[0].value for id in rr.values()] - combined_ids = np.concatenate((node_ids, sv_ids, upper_lvl_ids)) + chunk_sv_ids = np.unique(np.concatenate([id[0].value for id in rr.values()])) + chunk_l2_ids = np.unique(cg.get_parents(chunk_sv_ids, time_stamp=time_stamp)) + combined_ids = np.concatenate((node_ids, sv_ids, chunk_l2_ids)) root_ids = np.zeros(len(combined_ids), dtype=np.uint64) n_jobs = np.min([n_threads, len(combined_ids)]) multi_args = [] - start_ids = np.linspace(0, len(combined_ids), n_jobs + 1).astype(np.int) + start_ids = np.linspace(0, len(combined_ids), n_jobs + 1).astype(int) for i_block in range(n_jobs): multi_args.append([start_ids[i_block], start_ids[i_block + 1]]) @@ -747,7 +463,7 @@ def get_lx_overlapping_remappings_for_nodes_and_svs( time_stamp=None, n_threads: int = 1, ): - """ Retrieves sv id to layer mapping for chunk with overlap in positive + """Retrieves sv id to layer 
mapping for chunk with overlap in positive direction (one chunk) :param cg: chunkedgraph object @@ -760,40 +476,10 @@ def get_lx_overlapping_remappings_for_nodes_and_svs( """ if time_stamp is None: time_stamp = datetime.datetime.utcnow() - if time_stamp.tzinfo is None: time_stamp = UTC.localize(time_stamp) - chunk_coords = cg.get_chunk_coordinates(chunk_id) - chunk_layer = cg.get_chunk_layer(chunk_id) - - neigh_chunk_ids = [] - neigh_parent_chunk_ids = [] - - # Collect neighboring chunks and their parent chunk ids - # We only need to know about the parent chunk ids to figure the lowest - # common chunk - # Notice that the first neigh_chunk_id is equal to `chunk_id`. - for x in range(chunk_coords[0], chunk_coords[0] + 2): - for y in range(chunk_coords[1], chunk_coords[1] + 2): - for z in range(chunk_coords[2], chunk_coords[2] + 2): - - # Chunk id - neigh_chunk_id = cg.get_chunk_id(x=x, y=y, z=z, layer=chunk_layer) - neigh_chunk_ids.append(neigh_chunk_id) - - # Get parent chunk ids - parent_chunk_ids = cg.get_parent_chunk_ids(neigh_chunk_id) - neigh_parent_chunk_ids.append(parent_chunk_ids) - - # Find lowest common chunk - neigh_parent_chunk_ids = np.array(neigh_parent_chunk_ids) - layer_agreement = np.all( - (neigh_parent_chunk_ids - neigh_parent_chunk_ids[0]) == 0, axis=0 - ) - stop_layer = np.where(layer_agreement)[0][0] + 1 + chunk_layer - # stop_layer = cg.n_layers - + stop_layer, _ = calculate_stop_layer(cg, chunk_id) print(f"Stop layer: {stop_layer}") # Find the parent in the lowest common chunk for each node id and sv id. These parent @@ -803,16 +489,19 @@ def get_lx_overlapping_remappings_for_nodes_and_svs( cg, chunk_id, node_ids, sv_ids, stop_layer, time_stamp, n_threads ) - u_root_ids, c_root_ids = np.unique(chunks_root_ids, return_counts=True) + u_root_ids, u_idx, c_root_ids = np.unique( + chunks_root_ids, return_counts=True, return_index=True + ) # All l2 ids that share no root id with any other l2 id in the chunk are "safe", meaning # that we can easily obtain the complete remapping (including # overlap) for these. All other ones have to be resolved using the # segmentation. 
- temp_node_roots = u_root_ids[np.where(u_root_ids == node_root_ids)] - node_root_counts = c_root_ids[np.where(u_root_ids == node_root_ids)] - unsafe_root_ids = temp_node_roots[np.where(node_root_counts > 1)] + root_sorted_idx = np.argsort(u_root_ids) + node_sorted_index = np.searchsorted(u_root_ids[root_sorted_idx], node_root_ids) + node_root_counts = c_root_ids[root_sorted_idx][node_sorted_index] + unsafe_root_ids = node_root_ids[np.where(node_root_counts > 1)] safe_node_ids = node_ids[~np.isin(node_root_ids, unsafe_root_ids)] node_to_root_dict = dict(zip(node_ids, node_root_ids)) @@ -837,7 +526,7 @@ def get_lx_overlapping_remappings_for_nodes_and_svs( if len(sv_ids_to_add) > 0: relevant_node_ids = node_ids[np.where(node_root_ids == root_id)] if len(relevant_node_ids) > 0: - unsafe_dict[root_id].append(relevant_node_ids) + unsafe_dict[root_id].extend(relevant_node_ids) sv_ids_to_remap.extend(sv_ids_to_add) node_ids_flat.extend([root_id] * len(sv_ids_to_add)) @@ -847,28 +536,8 @@ def get_lx_overlapping_remappings_for_nodes_and_svs( return sv_remapping, unsafe_dict -def merge_meshes(meshes): - vertexct = np.zeros(len(meshes) + 1, np.uint32) - vertexct[1:] = np.cumsum([x["num_vertices"] for x in meshes]) - vertices = np.concatenate([x["vertices"] for x in meshes]) - faces = np.concatenate( - [mesh["faces"] + vertexct[i] for i, mesh in enumerate(meshes)] - ) - - if vertexct[-1] > 0: - # Remove duplicate vertices - vertices, faces = np.unique(vertices[faces], return_inverse=True, axis=0) - faces = faces.astype(np.uint32) - - return { - "num_vertices": np.uint32(len(vertices)), - "vertices": vertices, - "faces": faces, - } - - def get_meshing_necessities_from_graph(cg, chunk_id: np.uint64, mip: int): - """ Given a chunkedgraph, chunk_id, and mip level, return the voxel dimensions of the chunk to be meshed (mesh_block_shape) + """Given a chunkedgraph, chunk_id, and mip level, return the voxel dimensions of the chunk to be meshed (mesh_block_shape) and the chunk origin in the dataset in nm. 
:param cg: chunkedgraph instance @@ -878,9 +547,10 @@ def get_meshing_necessities_from_graph(cg, chunk_id: np.uint64, mip: int): layer = cg.get_chunk_layer(chunk_id) cx, cy, cz = cg.get_chunk_coordinates(chunk_id) mesh_block_shape = meshgen_utils.get_mesh_block_shape_for_mip(cg, layer, mip) + voxel_resolution = cg.meta.cv.mip_resolution(mip) chunk_offset = ( - (cx, cy, cz) * mesh_block_shape + cg.cv.mip_voxel_offset(mip) - ) * cg.cv.mip_resolution(mip) + (cx, cy, cz) * mesh_block_shape + cg.meta.cv.mip_voxel_offset(mip) + ) * voxel_resolution return layer, mesh_block_shape, chunk_offset @@ -891,7 +561,7 @@ def calculate_quantization_bits_and_range( draco_quantization_bits = np.ceil( np.log2(min_quantization_range / max_draco_bin_size + 1) ) - num_draco_bins = 2 ** draco_quantization_bits - 1 + num_draco_bins = 2**draco_quantization_bits - 1 draco_bin_size = np.ceil(min_quantization_range / num_draco_bins) draco_quantization_range = draco_bin_size * num_draco_bins if draco_quantization_range < min_quantization_range + draco_bin_size: @@ -905,11 +575,10 @@ def calculate_quantization_bits_and_range( return draco_quantization_bits, draco_quantization_range, draco_bin_size -# TODO: Bring over meshing readme from macastro-fafb-ingest-draco branch def get_draco_encoding_settings_for_chunk( cg, chunk_id: np.uint64, mip: int = 2, high_padding: int = 1 ): - """ Calculate the proper draco encoding settings for a chunk to ensure proper stitching is possible + """Calculate the proper draco encoding settings for a chunk to ensure proper stitching is possible on the layer above. For details about how and why we do this, please see the meshing Readme :param cg: chunkedgraph instance @@ -920,12 +589,16 @@ def get_draco_encoding_settings_for_chunk( _, mesh_block_shape, chunk_offset = get_meshing_necessities_from_graph( cg, chunk_id, mip ) - segmentation_resolution = cg.cv.scales[mip]["resolution"] + segmentation_resolution = cg.meta.cv.mip_resolution(mip) min_quantization_range = max( (mesh_block_shape + high_padding) * segmentation_resolution ) max_draco_bin_size = np.floor(min(segmentation_resolution) / np.sqrt(2)) - draco_quantization_bits, draco_quantization_range, draco_bin_size = calculate_quantization_bits_and_range( + ( + draco_quantization_bits, + draco_quantization_range, + draco_bin_size, + ) = calculate_quantization_bits_and_range( min_quantization_range, max_draco_bin_size ) draco_quantization_origin = chunk_offset - (chunk_offset % draco_bin_size) @@ -947,12 +620,16 @@ def get_next_layer_draco_encoding_settings( _, mesh_block_shape, chunk_offset = get_meshing_necessities_from_graph( cg, next_layer_chunk_id, mip ) - segmentation_resolution = cg.cv.scales[mip]["resolution"] + segmentation_resolution = cg.meta.cv.mip_resolution(mip) min_quantization_range = ( max(mesh_block_shape * segmentation_resolution) + 2 * old_draco_bin_size ) max_draco_bin_size = np.floor(min(segmentation_resolution) / np.sqrt(2)) - draco_quantization_bits, draco_quantization_range, draco_bin_size = calculate_quantization_bits_and_range( + ( + draco_quantization_bits, + draco_quantization_range, + draco_bin_size, + ) = calculate_quantization_bits_and_range( min_quantization_range, max_draco_bin_size ) draco_quantization_origin = ( @@ -1015,7 +692,9 @@ def transform_draco_fragment_and_return_encoding_options( return cur_encoding_settings -def merge_draco_meshes_across_boundaries(cg, fragments, chunk_id, mip, high_padding): +def merge_draco_meshes_across_boundaries( + cg, fragments, chunk_id, mip, high_padding, 
return_zmesh_object=False +): """ Merge a list of draco mesh fragments, removing duplicate vertices that lie on the chunk boundary where the meshes meet. @@ -1037,14 +716,16 @@ def merge_draco_meshes_across_boundaries(cg, fragments, chunk_id, mip, high_padd _, _, child_chunk_offset = get_meshing_necessities_from_graph( cg, child_chunk_id, mip ) - # Get the draco encoding settings for the child chunk in the "bottom corner" of the chunk_id chunk + # Get the draco encoding settings for the + # child chunk in the "bottom corner" of the chunk_id chunk draco_encoding_settings_smaller_chunk = get_draco_encoding_settings_for_chunk( cg, child_chunk_id, mip=mip, high_padding=high_padding ) draco_bin_size = draco_encoding_settings_smaller_chunk["quantization_range"] / ( 2 ** draco_encoding_settings_smaller_chunk["quantization_bits"] - 1 ) - # Calculate which draco bin the child chunk's boundaries were placed into (for each x,y,z of boundary) + # Calculate which draco bin the child chunk's boundaries + # were placed into (for each x,y,z of boundary) chunk_boundary_bin_index = np.floor( ( child_chunk_offset @@ -1094,6 +775,9 @@ def merge_draco_meshes_across_boundaries(cg, fragments, chunk_id, mip, high_padd # Remap the faces to their new vertex indices fastremap.remap(faces, faces_remapping, in_place=True) + if return_zmesh_object: + return zmesh.Mesh(vertices[:, 0:3], faces.reshape(-1, 3), None) + return { "num_vertices": np.uint32(len(vertices)), "vertices": vertices[:, 0:3].reshape(-1), @@ -1102,7 +786,8 @@ def merge_draco_meshes_across_boundaries(cg, fragments, chunk_id, mip, high_padd def black_out_dust_from_segmentation(seg, dust_threshold): - """ Black out (set to 0) IDs in segmentation not on the segmentation border that have less voxels than dust_threshold + """Black out (set to 0) IDs in segmentation not on the segmentation + border that have less voxels than dust_threshold :param seg: 3D segmentation (usually uint64) :param dust_threshold: int @@ -1129,16 +814,23 @@ def black_out_dust_from_segmentation(seg, dust_threshold): seg = fastremap.mask(seg, dust_segids, in_place=True) +def _get_timestamp_from_node_ids(cg, node_ids): + timestamps = cg.get_node_timestamps(node_ids, return_numpy=False) + return max(timestamps) + datetime.timedelta(milliseconds=1) + + def remeshing( cg, l2_node_ids: Sequence[np.uint64], + cv_sharded_mesh_dir: str, + cv_unsharded_mesh_path: str, stop_layer: int = None, - cv_path: str = None, - cv_mesh_dir: str = None, mip: int = 2, - max_err: int = 320, + max_err: int = 40, + time_stamp: datetime.datetime or None = None, ): - """ Given a chunkedgraph, a list of level 2 nodes, perform remeshing and stitching up the node hierarchy (or up to the stop_layer) + """Given a chunkedgraph, a list of level 2 nodes, + perform remeshing and stitching up the node hierarchy (or up to the stop_layer) :param cg: chunkedgraph instance :param l2_node_ids: list of uint64 @@ -1161,26 +853,33 @@ def add_nodes_to_l2_chunk_dict(ids): for chunk_id, node_ids in l2_chunk_dict.items(): if PRINT_FOR_DEBUGGING: print("remeshing", chunk_id, node_ids) + try: + l2_time_stamp = _get_timestamp_from_node_ids(cg, node_ids) + except ValueError: + # ignore bad/invalid messages + return # Remesh the l2_node_ids - chunk_mesh_task_new_remapping( - cg.get_serialized_info(), + chunk_initial_mesh_task( + None, chunk_id, - cg._cv_path, - cv_mesh_dir=cv_mesh_dir, mip=mip, - fragment_batch_size=20, node_id_subset=node_ids, cg=cg, + cv_unsharded_mesh_path=cv_unsharded_mesh_path, + max_err=max_err, + sharded=False, + 
time_stamp=l2_time_stamp, ) chunk_dicts = [] max_layer = stop_layer or cg._n_layers for layer in range(3, max_layer + 1): chunk_dicts.append(collections.defaultdict(set)) cur_chunk_dict = l2_chunk_dict - # Find the parents of each l2_node_id up to the stop_layer, as well as their associated chunk_ids + # Find the parents of each l2_node_id up to the stop_layer, + # as well as their associated chunk_ids for layer in range(3, max_layer + 1): for _, node_ids in cur_chunk_dict.items(): - parent_nodes = cg.get_parents(node_ids) + parent_nodes = cg.get_parents(node_ids, time_stamp=time_stamp) for parent_node in parent_nodes: chunk_layer = cg.get_chunk_layer(parent_node) index_in_dict_array = chunk_layer - 3 @@ -1193,314 +892,476 @@ def add_nodes_to_l2_chunk_dict(ids): if PRINT_FOR_DEBUGGING: print("remeshing", chunk_id, node_ids) # Stitch the meshes of the parents we found in the previous loop - chunk_mesh_task_new_remapping( - cg.get_serialized_info(), + chunk_stitch_remeshing_task( + None, chunk_id, - cg._cv_path, - cv_mesh_dir=cv_mesh_dir, mip=mip, - fragment_batch_size=20, + fragment_batch_size=40, node_id_subset=node_ids, cg=cg, + cv_sharded_mesh_dir=cv_sharded_mesh_dir, + cv_unsharded_mesh_path=cv_unsharded_mesh_path, ) -REDIS_HOST = os.environ.get("REDIS_SERVICE_HOST", "localhost") -REDIS_PORT = os.environ.get("REDIS_SERVICE_PORT", "6379") -REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "dev") -REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0" - -# @redis_job(REDIS_URL, 'mesh_frag_test_channel') -# TODO: refactor this bloated function -def chunk_mesh_task_new_remapping( - cg_info, +def chunk_initial_mesh_task( + cg_name, chunk_id, - cv_path, - cv_mesh_dir=None, + cv_unsharded_mesh_path, mip=2, - max_err=320, - base_layer=2, + max_err=40, lod=0, encoding="draco", time_stamp=None, dust_threshold=None, return_frag_count=False, - fragment_batch_size=None, node_id_subset=None, cg=None, + sharded=False, + cache=True, ): if cg is None: - cg = chunkedgraph.ChunkedGraph(**cg_info) - mesh_dir = cv_mesh_dir or cg._mesh_dir + cg = ChunkedGraph(graph_id=cg_name) result = [] + cache_string = "public" if cache else "no-cache" layer, _, chunk_offset = get_meshing_necessities_from_graph(cg, chunk_id, mip) cx, cy, cz = cg.get_chunk_coordinates(chunk_id) high_padding = 1 - if layer <= 2: - assert mip >= cg.cv.mip + assert layer == 2 + assert mip >= cg.meta.cv.mip - result.append((chunk_id, layer, cx, cy, cz)) - print( - "Retrieving remap table for chunk %s -- (%s, %s, %s, %s)" - % (chunk_id, layer, cx, cy, cz) + if sharded: + cv = CloudVolume( + f"graphene://https://localhost/segmentation/table/dummy", + info=meshgen_utils.get_json_info(cg), ) - mesher = zmesh.Mesher(cg.cv.mip_resolution(mip)) - draco_encoding_settings = get_draco_encoding_settings_for_chunk( - cg, chunk_id, mip, high_padding + sharding_info = cv.mesh.meta.info["sharding"]["2"] + sharding_spec = ShardingSpecification.from_dict(sharding_info) + merged_meshes = {} + mesh_dst = os.path.join( + cv.cloudpath, cv.mesh.meta.mesh_path, "initial", str(layer) ) - if node_id_subset is None: - seg = get_remapped_segmentation( - cg, chunk_id, mip, overlap_vx=high_padding, time_stamp=time_stamp - ) - else: - seg = get_remapped_seg_for_lvl2_nodes( - cg, - chunk_id, - node_id_subset, - mip=mip, - overlap_vx=high_padding, - time_stamp=time_stamp, - ) - if dust_threshold: - black_out_dust_from_segmentation(seg, dust_threshold) - if return_frag_count: - return np.unique(seg).shape[0] - mesher.mesh(seg.T) - del seg - with Storage(cv_path) 
as storage: - if PRINT_FOR_DEBUGGING: - print("cv path", cv_path) - print("mesh_dir", mesh_dir) - print("num ids", len(mesher.ids())) - result.append(len(mesher.ids())) - for obj_id in mesher.ids(): - mesh = mesher.get_mesh( - obj_id, - simplification_factor=999999, - max_simplification_error=max_err, - ) - mesher.erase(obj_id) - mesh.vertices[:] += chunk_offset - if encoding == "draco": - try: - file_contents = DracoPy.encode_mesh_to_buffer( - mesh.vertices.flatten("C"), - mesh.faces.flatten("C"), - **draco_encoding_settings, - ) - except: - result.append( - f"{obj_id} failed: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces" - ) - continue - compress = False - else: - file_contents = mesh.to_precomputed() - compress = True - if WRITING_TO_CLOUD: - storage.put_file( - file_path=f"{mesh_dir}/{meshgen_utils.get_mesh_name(cg, obj_id)}", - content=file_contents, - compress=compress, - cache_control="no-cache", - ) else: - # For each node with more than one child, create a new fragment by - # merging the mesh fragments of the children. + mesh_dst = cv_unsharded_mesh_path - print( - "Retrieving children for chunk %s -- (%s, %s, %s, %s)" - % (chunk_id, layer, cx, cy, cz) + result.append((chunk_id, layer, cx, cy, cz)) + print( + "Retrieving remap table for chunk %s -- (%s, %s, %s, %s)" + % (chunk_id, layer, cx, cy, cz) + ) + mesher = zmesh.Mesher(cg.meta.cv.mip_resolution(mip)) + draco_encoding_settings = get_draco_encoding_settings_for_chunk( + cg, chunk_id, mip, high_padding + ) + if node_id_subset is None: + seg = get_remapped_segmentation( + cg, chunk_id, mip, overlap_vx=high_padding, time_stamp=time_stamp ) - if node_id_subset is None: - range_read = cg.range_read_chunk( - layer, cx, cy, cz, columns=column_keys.Hierarchy.Child - ) + else: + seg = get_remapped_seg_for_lvl2_nodes( + cg, + chunk_id, + node_id_subset, + mip=mip, + overlap_vx=high_padding, + time_stamp=time_stamp, + ) + if dust_threshold: + black_out_dust_from_segmentation(seg, dust_threshold) + if return_frag_count: + return np.unique(seg).shape[0] + mesher.mesh(seg) + del seg + cf = CloudFiles(mesh_dst) + if PRINT_FOR_DEBUGGING: + print("cv path", mesh_dst) + print("num ids", len(mesher.ids())) + result.append(len(mesher.ids())) + for obj_id in mesher.ids(): + mesh = mesher.get(obj_id, reduction_factor=100, max_error=max_err) + mesher.erase(obj_id) + mesh.vertices[:] += chunk_offset + if encoding == "draco": + try: + file_contents = DracoPy.encode_mesh_to_buffer( + mesh.vertices.flatten("C"), + mesh.faces.flatten("C"), + **draco_encoding_settings, + ) + except: + result.append( + f"{obj_id} failed: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces" + ) + continue + compress = False else: - range_read = cg.read_node_id_rows( - node_ids=node_id_subset, columns=column_keys.Hierarchy.Child - ) + file_contents = mesh.to_precomputed() + compress = True + if WRITING_TO_CLOUD: + if sharded: + merged_meshes[int(obj_id)] = file_contents + else: + cf.put( + path=f"{meshgen_utils.get_mesh_name(cg, obj_id)}", + content=file_contents, + compress=compress, + cache_control=cache_string, + ) + if sharded and WRITING_TO_CLOUD: + shard_binary = sharding_spec.synthesize_shard(merged_meshes) + shard_filename = cv.mesh.readers[layer].get_filename(chunk_id) + cf.put( + shard_filename, + shard_binary, + content_type="application/octet-stream", + compress=False, + cache_control=cache_string, + ) + if PRINT_FOR_DEBUGGING: + print(", ".join(str(x) for x in result)) + return result - print("Collecting only nodes with more than one child: ", 
end="") - node_ids = np.array(list(range_read.keys())) - node_rows = np.array(list(range_read.values())) - child_fragments = np.array( - [ - fragment.value - for child_fragments_for_node in node_rows - for fragment in child_fragments_for_node - ] - ) - # Only keep nodes with more than one child - multi_child_mask = [len(fragments) > 1 for fragments in child_fragments] - multi_child_node_ids = node_ids[multi_child_mask] - multi_child_children_ids = child_fragments[multi_child_mask] - # Store how many children each node has, because we will retrieve all children at once - multi_child_num_children = [ - len(children) for children in multi_child_children_ids - ] - child_fragments_flat = np.array( - [ - frag - for children_of_node in multi_child_children_ids - for frag in children_of_node - ] +def get_multi_child_nodes(cg, chunk_id, node_id_subset=None, chunk_bbox_string=False): + if node_id_subset is None: + range_read = cg.range_read_chunk( + chunk_id, properties=attributes.Hierarchy.Child ) - multi_child_descendants = meshgen_utils.get_downstream_multi_child_nodes( - cg, child_fragments_flat + else: + range_read = cg.client.read_nodes( + node_ids=node_id_subset, properties=attributes.Hierarchy.Child ) - start_index = 0 - multi_child_nodes = {} - for i in range(len(multi_child_node_ids)): - end_index = start_index + multi_child_num_children[i] - descendents_for_current_node = multi_child_descendants[ - start_index:end_index - ] - node_id = multi_child_node_ids[i] + + node_ids = np.array(list(range_read.keys())) + node_rows = np.array(list(range_read.values())) + child_fragments = np.array( + [ + fragment.value + for child_fragments_for_node in node_rows + for fragment in child_fragments_for_node + ], dtype=object + ) + # Filter out node ids that do not have roots (caused by failed ingest tasks) + root_ids = cg.get_roots(node_ids, fail_to_zero=True) + # Only keep nodes with more than one child + multi_child_mask = np.array( + [len(fragments) > 1 for fragments in child_fragments], dtype=bool + ) + root_id_mask = np.array([root_id != 0 for root_id in root_ids], dtype=bool) + multi_child_node_ids = node_ids[multi_child_mask & root_id_mask] + multi_child_children_ids = child_fragments[multi_child_mask & root_id_mask] + # Store how many children each node has, because we will retrieve all children at once + multi_child_num_children = [len(children) for children in multi_child_children_ids] + child_fragments_flat = np.array( + [ + frag + for children_of_node in multi_child_children_ids + for frag in children_of_node + ] + ) + multi_child_descendants = meshgen_utils.get_downstream_multi_child_nodes( + cg, child_fragments_flat + ) + start_index = 0 + multi_child_nodes = {} + for i in range(len(multi_child_node_ids)): + end_index = start_index + multi_child_num_children[i] + descendents_for_current_node = multi_child_descendants[start_index:end_index] + node_id = multi_child_node_ids[i] + if chunk_bbox_string: multi_child_nodes[ f"{node_id}:0:{meshgen_utils.get_chunk_bbox_str(cg, node_id)}" ] = [ f"{c}:0:{meshgen_utils.get_chunk_bbox_str(cg, c)}" for c in descendents_for_current_node ] - start_index = end_index - print("%d out of %d" % (len(multi_child_nodes), len(node_ids))) - result.append((chunk_id, len(multi_child_nodes), len(node_ids))) - if not multi_child_nodes: - print("Nothing to do", cx, cy, cz) - return ", ".join(str(x) for x in result) - - with Storage(os.path.join(cv_path, mesh_dir)) as storage: - vals = multi_child_nodes.values() - fragment_to_fetch = [ - fragment for child_fragments in 
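A small worked example of the filtering that `get_multi_child_nodes` performs above: nodes are kept only if they have more than one child and still resolve to a non-zero root (`get_roots(..., fail_to_zero=True)` returns 0 for nodes orphaned by failed ingest tasks). The ids below are made up.

```python
import numpy as np

node_ids = np.array([10, 11, 12, 13], dtype=np.uint64)
child_counts = np.array([3, 1, 2, 4])                  # children per node
root_ids = np.array([7, 7, 0, 7], dtype=np.uint64)     # node 12 has no root

multi_child_mask = child_counts > 1
root_id_mask = root_ids != 0
kept = node_ids[multi_child_mask & root_id_mask]
assert kept.tolist() == [10, 13]
```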
vals for fragment in child_fragments - ] - if fragment_batch_size is None: - files_contents = storage.get_files(fragment_to_fetch) - else: - files_contents = storage.get_files( - fragment_to_fetch[0:fragment_batch_size] - ) - fragments_in_batch_processed = 0 - batches_processed = 0 - num_fragments_processed = 0 - fragment_map = {} - for i in range(len(files_contents)): - fragment_map[files_contents[i]["filename"]] = files_contents[i] - i = 0 - for new_fragment_id, fragment_ids_to_fetch in multi_child_nodes.items(): - i += 1 - if i % max(1, len(multi_child_nodes) // 10) == 0: - print(f"{i}/{len(multi_child_nodes)}") - - old_fragments = [] - missing_fragments = False - for fragment_id in fragment_ids_to_fetch: - if fragment_batch_size is not None: - fragments_in_batch_processed += 1 - if fragments_in_batch_processed > fragment_batch_size: - fragments_in_batch_processed = 1 - batches_processed += 1 - num_fragments_processed = ( - batches_processed * fragment_batch_size - ) - files_contents = storage.get_files( - fragment_to_fetch[ - num_fragments_processed : num_fragments_processed - + fragment_batch_size - ] - ) - fragment_map = {} - for j in range(len(files_contents)): - fragment_map[ - files_contents[j]["filename"] - ] = files_contents[j] - fragment = fragment_map[fragment_id] - filename = fragment["filename"] - end_of_node_id_index = filename.find(":") - if end_of_node_id_index == -1: - print( - f"Unexpected filename {filename}. Filenames expected in format '\{node_id}:\{lod}:\{meshgen_utils.get_chunk_bbox_str(cg, node_id)}'" - ) - missing_fragments = True - node_id_str = filename[:end_of_node_id_index] - if fragment["content"] is not None and fragment["error"] is None: - try: - old_fragments.append( - { - "mesh": decode_draco_mesh_buffer( - fragment["content"] - ), - "node_id": np.uint64(node_id_str), - } - ) - except: - missing_fragments = True - new_fragment_str = new_fragment_id[ - 0 : new_fragment_id.find(":") - ] - result.append( - f"Decoding failed for {node_id_str} in {new_fragment_str}" - ) - elif cg.get_chunk_layer(np.uint64(node_id_str)) > 2: - result.append(f"{fragment_id} missing for {new_fragment_id}") - - if len(old_fragments) == 0 or missing_fragments: - result.append(f"No meshes for {new_fragment_id}") - continue - - draco_encoding_options = None - for old_fragment in old_fragments: - if draco_encoding_options is None: - draco_encoding_options = transform_draco_fragment_and_return_encoding_options( - cg, old_fragment, layer, mip, chunk_id - ) - else: - encoding_options_for_fragment = transform_draco_fragment_and_return_encoding_options( - cg, old_fragment, layer, mip, chunk_id - ) - np.testing.assert_equal( - draco_encoding_options["quantization_bits"], - encoding_options_for_fragment["quantization_bits"], - ) - np.testing.assert_equal( - draco_encoding_options["quantization_range"], - encoding_options_for_fragment["quantization_range"], - ) - np.testing.assert_array_equal( - draco_encoding_options["quantization_origin"], - encoding_options_for_fragment["quantization_origin"], - ) + else: + multi_child_nodes[multi_child_node_ids[i]] = descendents_for_current_node + start_index = end_index - new_fragment = merge_draco_meshes_across_boundaries( - cg, old_fragments, chunk_id, mip, high_padding - ) + return multi_child_nodes, multi_child_descendants - try: - new_fragment_b = DracoPy.encode_mesh_to_buffer( - new_fragment["vertices"], - new_fragment["faces"], - **draco_encoding_options, - ) - except: - new_fragment_str = new_fragment_id[0 : new_fragment_id.find(":")] - 
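Both the removed stitching code above and the new `chunk_stitch_remeshing_task` page fragment downloads in fixed-size batches instead of fetching every child mesh at once. A minimal, pure-Python illustration of that paging, with the remote fetch faked:

```python
def fake_fetch(ids):
    # stand-in for storage.get_files / cv.mesh.get_meshes_on_bypass
    return {i: f"mesh-bytes-for-{i}" for i in ids}

fragment_to_fetch = list(range(10))
batch_size = 4
fragment_maps = [
    fake_fetch(fragment_to_fetch[start:start + batch_size])
    for start in range(0, len(fragment_to_fetch), batch_size)
]
assert [len(m) for m in fragment_maps] == [4, 4, 2]
```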
result.append( - f'Bad mesh created for {new_fragment_str}: {len(new_fragment["vertices"])} vertices, {len(new_fragment["faces"])} faces' + +def chunk_stitch_remeshing_task( + cg_name, + chunk_id, + cv_sharded_mesh_dir, + cv_unsharded_mesh_path, + mip=2, + lod=0, + fragment_batch_size=None, + node_id_subset=None, + cg=None, + high_padding=1, +): + """ + For each node with more than one child, create a new fragment by + merging the mesh fragments of the children. + """ + if cg is None: + cg = ChunkedGraph(graph_id=cg_name) + cx, cy, cz = cg.get_chunk_coordinates(chunk_id) + layer = cg.get_chunk_layer(chunk_id) + result = [] + + assert layer > 2 + + print( + "Retrieving children for chunk %s -- (%s, %s, %s, %s)" + % (chunk_id, layer, cx, cy, cz) + ) + + multi_child_nodes, _ = get_multi_child_nodes(cg, chunk_id, node_id_subset, False) + print(f"{len(multi_child_nodes)} nodes with more than one child") + result.append((chunk_id, len(multi_child_nodes))) + if not multi_child_nodes: + print("Nothing to do", cx, cy, cz) + return ", ".join(str(x) for x in result) + + cv = CloudVolume( + f"graphene://https://localhost/segmentation/table/dummy", + mesh_dir=cv_sharded_mesh_dir, + info=meshgen_utils.get_json_info(cg), + ) + + fragments_in_batch_processed = 0 + batches_processed = 0 + num_fragments_processed = 0 + fragment_to_fetch = [ + fragment + for child_fragments in multi_child_nodes.values() + for fragment in child_fragments + ] + cf = CloudFiles(cv_unsharded_mesh_path) + if fragment_batch_size is None: + fragment_map = cv.mesh.get_meshes_on_bypass( + fragment_to_fetch, allow_missing=True + ) + else: + fragment_map = cv.mesh.get_meshes_on_bypass( + fragment_to_fetch[0:fragment_batch_size], allow_missing=True + ) + i = 0 + fragments_d = {} + for new_fragment_id, fragment_ids_to_fetch in multi_child_nodes.items(): + i += 1 + if i % max(1, len(multi_child_nodes) // 10) == 0: + print(f"{i}/{len(multi_child_nodes)}") + + old_fragments = [] + missing_fragments = False + for fragment_id in fragment_ids_to_fetch: + if fragment_batch_size is not None: + fragments_in_batch_processed += 1 + if fragments_in_batch_processed > fragment_batch_size: + fragments_in_batch_processed = 1 + batches_processed += 1 + num_fragments_processed = batches_processed * fragment_batch_size + fragment_map = cv.mesh.get_meshes_on_bypass( + fragment_to_fetch[ + num_fragments_processed : num_fragments_processed + + fragment_batch_size + ], + allow_missing=True, ) - continue - - if WRITING_TO_CLOUD: - storage.put_file( - new_fragment_id, - new_fragment_b, - content_type="application/octet-stream", - compress=False, - cache_control="no-cache", + if fragment_id in fragment_map: + old_frag = fragment_map[fragment_id] + new_old_frag = { + "num_vertices": len(old_frag.vertices), + "vertices": old_frag.vertices, + "faces": old_frag.faces.reshape(-1), + "encoding_options": old_frag.encoding_options, + "encoding_type": "draco", + } + wrapper_object = { + "mesh": new_old_frag, + "node_id": np.uint64(old_frag.segid), + } + old_fragments.append(wrapper_object) + elif cg.get_chunk_layer(np.uint64(fragment_id)) > 2: + missing_fragments = True + result.append(f"{fragment_id} missing for {new_fragment_id}") + + if len(old_fragments) == 0 or missing_fragments: + result.append(f"No meshes for {new_fragment_id}") + continue + + draco_encoding_options = None + for old_fragment in old_fragments: + if draco_encoding_options is None: + draco_encoding_options = ( + transform_draco_fragment_and_return_encoding_options( + cg, old_fragment, layer, mip, 
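A sketch of invoking the `chunk_stitch_remeshing_task` defined in this hunk for one layer-3 chunk. Graph id, mesh directories, and chunk coordinates are placeholders; the call shape mirrors the remeshing caller earlier in this file.

```python
from pychunkedgraph.graph.chunkedgraph import ChunkedGraph
from pychunkedgraph.meshing import meshgen

cg = ChunkedGraph(graph_id="my_graph")                 # hypothetical graph id
chunk_id = cg.get_chunk_id(layer=3, x=5, y=6, z=1)     # hypothetical coordinates
summary = meshgen.chunk_stitch_remeshing_task(
    None,                                   # cg_name is unused when cg is passed
    chunk_id,
    cv_sharded_mesh_dir="graphene_meshes",                  # placeholder
    cv_unsharded_mesh_path="gs://my-bucket/meshes/dynamic", # placeholder
    mip=2,
    fragment_batch_size=40,
    cg=cg,
)
print(summary)
```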
chunk_id ) + ) + else: + transform_draco_fragment_and_return_encoding_options( + cg, old_fragment, layer, mip, chunk_id + ) + + new_fragment = merge_draco_meshes_across_boundaries( + cg, old_fragments, chunk_id, mip, high_padding + ) + + try: + new_fragment_b = DracoPy.encode_mesh_to_buffer( + new_fragment["vertices"], + new_fragment["faces"], + **draco_encoding_options, + ) + except: + result.append( + f'Bad mesh created for {new_fragment_id}: {len(new_fragment["vertices"])} ' + f'vertices, {len(new_fragment["faces"])} faces' + ) + continue + + if WRITING_TO_CLOUD: + fragment_name = meshgen_utils.get_chunk_bbox_str(cg, new_fragment_id) + fragment_name = f"{new_fragment_id}:0:{fragment_name}" + fragments_d[new_fragment_id] = fragment_name + cf.put( + fragment_name, + new_fragment_b, + content_type="application/octet-stream", + compress=False, + cache_control="public", + ) + + manifest_cache = ManifestCache(cg.graph_id, initial=False) + manifest_cache.set_fragments(fragments_d) if PRINT_FOR_DEBUGGING: print(", ".join(str(x) for x in result)) return ", ".join(str(x) for x in result) + +def chunk_initial_sharded_stitching_task( + cg_name, chunk_id, mip, cg=None, high_padding=1, cache=True +): + start_existence_check_time = time.time() + if cg is None: + cg = ChunkedGraph(graph_id=cg_name) + + cache_string = "public" if cache else "no-cache" + + layer = cg.get_chunk_layer(chunk_id) + multi_child_nodes, multi_child_descendants = get_multi_child_nodes(cg, chunk_id) + + chunk_to_id_dict = collections.defaultdict(list) + for child_node in multi_child_descendants: + cur_chunk_id = int(cg.get_chunk_id(child_node)) + chunk_to_id_dict[cur_chunk_id].append(child_node) + + cv = CloudVolume( + f"graphene://https://localhost/segmentation/table/dummy", + info=meshgen_utils.get_json_info(cg), + ) + shard_filenames = [] + shard_to_chunk_id = {} + for cur_chunk_id in chunk_to_id_dict: + shard_id = cv.meta.decode_chunk_position_number(cur_chunk_id) + shard_filename = ( + str(cg.get_chunk_layer(cur_chunk_id)) + "/" + str(shard_id) + "-0.shard" + ) + shard_to_chunk_id[shard_filename] = cur_chunk_id + shard_filenames.append(shard_filename) + mesh_dict = {} + + cf = CloudFiles(os.path.join(cv.cloudpath, cv.mesh.meta.mesh_path, "initial")) + files_contents = cf.get(shard_filenames) + for i in range(len(files_contents)): + cur_chunk_id = shard_to_chunk_id[files_contents[i]["path"]] + cur_layer = cg.get_chunk_layer(cur_chunk_id) + if files_contents[i]["content"] is not None: + disassembled_shard = cv.mesh.readers[cur_layer].disassemble_shard( + files_contents[i]["content"] + ) + nodes_in_chunk = chunk_to_id_dict[int(cur_chunk_id)] + for node_in_chunk in nodes_in_chunk: + node_in_chunk_int = int(node_in_chunk) + if node_in_chunk_int in disassembled_shard: + mesh_dict[node_in_chunk_int] = disassembled_shard[node_in_chunk] + del files_contents + + number_frags_proc = 0 + sharding_info = cv.mesh.meta.info["sharding"][str(layer)] + sharding_spec = ShardingSpecification.from_dict(sharding_info) + merged_meshes = {} + biggest_frag = 0 + biggest_frag_vx_ct = 0 + bad_meshes = [] + for new_fragment_id in multi_child_nodes: + fragment_ids_to_fetch = multi_child_nodes[new_fragment_id] + old_fragments = [] + for frag_to_fetch in fragment_ids_to_fetch: + try: + old_fragments.append( + { + "mesh": decode_draco_mesh_buffer(mesh_dict[int(frag_to_fetch)]), + "node_id": np.uint64(frag_to_fetch), + } + ) + except KeyError: + pass + if len(old_fragments) > 0: + draco_encoding_options = None + for old_fragment in old_fragments: + if 
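The sharded stitching task groups descendants by their chunk id so that each shard file only has to be read once. A toy worked example of that grouping (the chunk-id function is a stand-in for `cg.get_chunk_id`):

```python
import collections

def chunk_of(node_id):
    return node_id // 100        # toy stand-in for cg.get_chunk_id

descendants = [101, 102, 205, 206, 301]
chunk_to_id_dict = collections.defaultdict(list)
for node in descendants:
    chunk_to_id_dict[chunk_of(node)].append(node)

assert dict(chunk_to_id_dict) == {1: [101, 102], 2: [205, 206], 3: [301]}
```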
draco_encoding_options is None: + draco_encoding_options = ( + transform_draco_fragment_and_return_encoding_options( + cg, old_fragment, layer, mip, chunk_id + ) + ) + else: + transform_draco_fragment_and_return_encoding_options( + cg, old_fragment, layer, mip, chunk_id + ) + + new_fragment = merge_draco_meshes_across_boundaries( + cg, old_fragments, chunk_id, mip, high_padding + ) + + if len(new_fragment["vertices"]) > biggest_frag_vx_ct: + biggest_frag = new_fragment_id + biggest_frag_vx_ct = len(new_fragment["vertices"]) + + try: + new_fragment_b = DracoPy.encode_mesh_to_buffer( + new_fragment["vertices"], + new_fragment["faces"], + **draco_encoding_options, + ) + merged_meshes[int(new_fragment_id)] = new_fragment_b + except: + print(f"failed to merge {new_fragment_id}") + bad_meshes.append(new_fragment_id) + pass + number_frags_proc = number_frags_proc + 1 + if number_frags_proc % 1000 == 0: + print(f"number frag proc = {number_frags_proc}") + del mesh_dict + shard_binary = sharding_spec.synthesize_shard(merged_meshes) + shard_filename = cv.mesh.readers[layer].get_filename(chunk_id) + cf = CloudFiles( + os.path.join(cv.cloudpath, cv.mesh.meta.mesh_path, "initial", str(layer)) + ) + cf.put( + shard_filename, + shard_binary, + content_type="application/octet-stream", + compress=False, + cache_control=cache_string, + ) + total_time = time.time() - start_existence_check_time + + ret = { + "chunk_id": chunk_id, + "total_time": total_time, + "biggest_frag": biggest_frag, + "biggest_frag_vx_ct": biggest_frag_vx_ct, + "number_frag": number_frags_proc, + "bad meshes": bad_meshes, + } + return ret diff --git a/pychunkedgraph/meshing/meshgen_utils.py b/pychunkedgraph/meshing/meshgen_utils.py index 075af31c8..711c09322 100644 --- a/pychunkedgraph/meshing/meshgen_utils.py +++ b/pychunkedgraph/meshing/meshgen_utils.py @@ -1,26 +1,35 @@ import re -import numpy as np -import time - -from functools import lru_cache -from cloudvolume import CloudVolume, Storage +import multiprocessing as mp +from time import time +from typing import List +from typing import Dict +from typing import Tuple from typing import Sequence +from functools import lru_cache -from pychunkedgraph.backend import chunkedgraph # noqa +import numpy as np +from cloudvolume import CloudVolume +from cloudvolume.lib import Vec +from multiwrapper import multiprocessing_utils as mu + +from pychunkedgraph.graph.utils.basetypes import NODE_ID # noqa +from ..graph.types import empty_1d def str_to_slice(slice_str: str): match = re.match(r"(\d+)-(\d+)_(\d+)-(\d+)_(\d+)-(\d+)", slice_str) - return (slice(int(match.group(1)), int(match.group(2))), - slice(int(match.group(3)), int(match.group(4))), - slice(int(match.group(5)), int(match.group(6)))) + return ( + slice(int(match.group(1)), int(match.group(2))), + slice(int(match.group(3)), int(match.group(4))), + slice(int(match.group(5)), int(match.group(6))), + ) def slice_to_str(slices) -> str: if isinstance(slices, slice): return "%d-%d" % (slices.start, slices.stop) else: - return '_'.join(map(slice_to_str, slices)) + return "_".join(map(slice_to_str, slices)) def get_chunk_bbox(cg, chunk_id: np.uint64): @@ -41,7 +50,7 @@ def get_mesh_name(cg, node_id: np.uint64) -> str: @lru_cache(maxsize=None) def get_segmentation_info(cg) -> dict: - return cg.dataset_info + return cg.meta.dataset_info def get_mesh_block_shape(cg, graphlayer: int) -> np.ndarray: @@ -50,11 +59,12 @@ def get_mesh_block_shape(cg, graphlayer: int) -> np.ndarray: the same region as a ChunkedGraph chunk at layer `graphlayer`. 
""" # Segmentation is not always uniformly downsampled in all directions. - return cg.chunk_size * cg.fan_out ** np.max([0, graphlayer - 2]) + return np.array( + cg.meta.graph_config.CHUNK_SIZE + ) * cg.meta.graph_config.FANOUT ** np.max([0, graphlayer - 2]) -def get_mesh_block_shape_for_mip(cg, graphlayer: int, - source_mip: int) -> np.ndarray: +def get_mesh_block_shape_for_mip(cg, graphlayer: int, source_mip: int) -> np.ndarray: """ Calculate the dimensions of a segmentation block at `source_mip` that covers the same region as a ChunkedGraph chunk at layer `graphlayer`. @@ -62,18 +72,20 @@ def get_mesh_block_shape_for_mip(cg, graphlayer: int, info = get_segmentation_info(cg) # Segmentation is not always uniformly downsampled in all directions. - scale_0 = info['scales'][0] - scale_mip = info['scales'][source_mip] - distortion = np.floor_divide(scale_mip['resolution'], scale_0['resolution']) + scale_0 = info["scales"][0] + scale_mip = info["scales"][source_mip] + distortion = np.floor_divide(scale_mip["resolution"], scale_0["resolution"]) - graphlayer_chunksize = cg.chunk_size * cg.fan_out ** np.max([0, graphlayer - 2]) + graphlayer_chunksize = np.array( + cg.meta.graph_config.CHUNK_SIZE + ) * cg.meta.graph_config.FANOUT ** np.max([0, graphlayer - 2]) - return np.floor_divide(graphlayer_chunksize, distortion, dtype=np.int, - casting='unsafe') + return np.floor_divide( + graphlayer_chunksize, distortion, dtype=int, casting="unsafe" + ) -def get_downstream_multi_child_node(cg, node_id: np.uint64, - stop_layer: int = 1): +def get_downstream_multi_child_node(cg, node_id: np.uint64, stop_layer: int = 1): """ Return the first descendant of `node_id` (including itself) with more than one child, or the first descendant of `node_id` (including itself) on or @@ -93,76 +105,75 @@ def get_downstream_multi_child_node(cg, node_id: np.uint64, return get_downstream_multi_child_node(cg, children[0], stop_layer) -def get_downstream_multi_child_nodes(cg, node_ids: Sequence[np.uint64], require_children=True): +def get_downstream_multi_child_nodes( + cg, node_ids: Sequence[np.uint64], require_children=True +): """ Return the first descendant of `node_ids` (including themselves) with more than - one child, or the first descendant of `node_id` (including itself) on or + one child, or the first descendant of `node_ids` (including themselves) on or below layer 2. """ # FIXME: Make stop_layer configurable stop_layer = 2 - node_ids_to_return = np.copy(node_ids) def recursive_helper(cur_node_ids): - stop_layer_mask = np.array([cg.get_chunk_layer(node_id) > stop_layer for node_id in cur_node_ids]) + cur_node_ids, unique_to_original = np.unique(cur_node_ids, return_inverse=True) + stop_layer_mask = np.array( + [cg.get_chunk_layer(node_id) > stop_layer for node_id in cur_node_ids] + ) if np.any(stop_layer_mask): node_to_children_dict = cg.get_children(cur_node_ids[stop_layer_mask]) - children_array = np.array(list(node_to_children_dict.values())) - if require_children and len(children_array) < len(cur_node_ids[stop_layer_mask]): - raise ValueError('Not all node_ids have children. 
May be mixing node_ids from different generations.') - only_child_mask = np.array([len(children_for_node) == 1 for children_for_node in children_array]) + children_array = np.array( + list(node_to_children_dict.values()), dtype=object + ) + only_child_mask = np.array( + [len(children_for_node) == 1 for children_for_node in children_array] + ) only_children = children_array[only_child_mask].astype(np.uint64).ravel() if np.any(only_child_mask): temp_array = cur_node_ids[stop_layer_mask] temp_array[only_child_mask] = recursive_helper(only_children) cur_node_ids[stop_layer_mask] = temp_array - return cur_node_ids - - return recursive_helper(node_ids_to_return) - - -def get_highest_child_nodes_with_meshes(cg, - node_id: np.uint64, - stop_layer=1, - start_layer=None, - verify_existence=False, - bounding_box=None): - if start_layer is None: - start_layer = cg.n_layers - - candidates = cg.get_subgraph_nodes( - node_id, - bounding_box=bounding_box, - bb_is_coordinate=True, - return_layers=[start_layer]) - - if verify_existence: - valid_node_ids = [] - with Storage(cg.cv_mesh_path) as stor: - while True: - filenames = [get_mesh_name(cg, c) for c in candidates] - - time_start = time.time() - existence_dict = stor.files_exist(filenames) - print("Existence took: %.3fs" % (time.time() - time_start)) - - missing_meshes = [] - for mesh_key in existence_dict: - node_id = np.uint64(mesh_key.split(':')[0]) - if existence_dict[mesh_key]: - valid_node_ids.append(node_id) - else: - if cg.get_chunk_layer(node_id) > stop_layer: - missing_meshes.append(node_id) - - time_start = time.time() - if missing_meshes: - candidates = cg.get_children(missing_meshes, flatten=True) - else: - break - print("ChunkedGraph lookup took: %.3fs" % (time.time() - time_start)) + return cur_node_ids[unique_to_original] - else: - valid_node_ids = candidates + return recursive_helper(node_ids) + + +def get_json_info(cg): + from json import loads, dumps + + dataset_info = cg.meta.dataset_info + dummy_app_info = {"app": {"supported_api_versions": [0, 1]}} + info = {**dataset_info, **dummy_app_info} + info["mesh"] = cg.meta.custom_data.get("mesh", {}).get("dir", "graphene_meshes") + info_str = dumps(info) + return loads(info_str) + + +def get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx=1): + cv = CloudVolume(cg.meta.cv.cloudpath, mip=mip, fill_missing=True) + mip_diff = mip - cg.meta.cv.mip + + mip_chunk_size = np.array(cg.meta.graph_config.CHUNK_SIZE, dtype=int) / np.array( + [2 ** mip_diff, 2 ** mip_diff, 1] + ) + mip_chunk_size = mip_chunk_size.astype(int) + + chunk_start = ( + cg.meta.cv.mip_voxel_offset(mip) + + cg.get_chunk_coordinates(chunk_id) * mip_chunk_size + ) + chunk_end = chunk_start + mip_chunk_size + overlap_vx + chunk_end = Vec.clamp( + chunk_end, + cg.meta.cv.mip_voxel_offset(mip), + cg.meta.cv.mip_voxel_offset(mip) + cg.meta.cv.mip_volume_size(mip), + ) + + ws_seg = cv[ + chunk_start[0] : chunk_end[0], + chunk_start[1] : chunk_end[1], + chunk_start[2] : chunk_end[2], + ].squeeze() - return valid_node_ids + return ws_seg diff --git a/pychunkedgraph/meshing/meshing_batch.py b/pychunkedgraph/meshing/meshing_batch.py new file mode 100644 index 000000000..6f40fb0a0 --- /dev/null +++ b/pychunkedgraph/meshing/meshing_batch.py @@ -0,0 +1,65 @@ +import argparse, os +import numpy as np +from cloudvolume import CloudVolume +from cloudfiles import CloudFiles +from taskqueue import TaskQueue, LocalTaskQueue + +from pychunkedgraph.graph.chunkedgraph import ChunkedGraph # noqa +from pychunkedgraph.meshing.meshing_sqs import MeshTask 
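A worked example of the chunk-size arithmetic behind `get_mesh_block_shape` and `get_ws_seg_for_chunk` in this hunk, assuming `CHUNK_SIZE = [512, 512, 64]`, `FANOUT = 2`, a watershed at mip 0 and meshing at mip 2 (so x/y are downsampled 4x and z is not). The chunk coordinates are hypothetical and the voxel offset is taken as zero.

```python
import numpy as np

CHUNK_SIZE = np.array([512, 512, 64], dtype=int)
FANOUT = 2

# Block shape covered by a graph chunk at a given layer (get_mesh_block_shape).
assert np.array_equal(CHUNK_SIZE * FANOUT ** max(0, 2 - 2), [512, 512, 64])
assert np.array_equal(CHUNK_SIZE * FANOUT ** max(0, 4 - 2), [2048, 2048, 256])

# Voxel bounds of a layer-2 chunk at mip 2 with one voxel of overlap
# (get_ws_seg_for_chunk).
mip_diff = 2 - 0
mip_chunk_size = (CHUNK_SIZE / np.array([2 ** mip_diff, 2 ** mip_diff, 1])).astype(int)
chunk_coords = np.array([3, 1, 0])
chunk_start = chunk_coords * mip_chunk_size
chunk_end = chunk_start + mip_chunk_size + 1
assert np.array_equal(chunk_start, [384, 128, 0])
assert np.array_equal(chunk_end, [513, 257, 65])
```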
+from pychunkedgraph.meshing import meshgen_utils # noqa + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--queue_name', type=str, default=None) + parser.add_argument('--chunk_start', nargs=3, type=int) + parser.add_argument('--chunk_end', nargs=3, type=int) + parser.add_argument('--cg_name', type=str) + parser.add_argument('--layer', type=int) + parser.add_argument('--mip', type=int) + parser.add_argument('--skip_cache', action='store_true') + parser.add_argument('--overwrite', type=bool, default=False) + + args = parser.parse_args() + cache = not args.skip_cache + + cg = ChunkedGraph(graph_id=args.cg_name) + cv = CloudVolume( + f"graphene://https://localhost/segmentation/table/dummy", + info=meshgen_utils.get_json_info(cg), + ) + dst = os.path.join( + cv.cloudpath, cv.mesh.meta.mesh_path, "initial", str(args.layer) + ) + cf = CloudFiles(dst) + if len(list(cf.list())) > 0 and not args.overwrite: + raise ValueError(f"Destination {dst} is not empty. Use `--overwrite true` to proceed anyway.") + + chunks_arr = [] + for x in range(args.chunk_start[0],args.chunk_end[0]): + for y in range(args.chunk_start[1], args.chunk_end[1]): + for z in range(args.chunk_start[2], args.chunk_end[2]): + chunks_arr.append((x, y, z)) + + np.random.shuffle(chunks_arr) + + class MeshTaskIterator(object): + def __init__(self, chunks): + self.chunks = chunks + def __iter__(self): + if args.overwrite: + meshed = set() + else: + meshed = set(cf.list()) + for chunk in self.chunks: + chunk_id = cg.get_chunk_id(layer=args.layer, x=chunk[0], y=chunk[1], z=chunk[2]) + shard_filename = cv.mesh.readers[args.layer].get_filename(chunk_id) + if shard_filename in meshed: + continue + yield MeshTask(args.cg_name, args.layer, int(chunk_id), args.mip, cache) + + if args.queue_name is not None: + with TaskQueue(args.queue_name) as tq: + tq.insert_all(MeshTaskIterator(chunks_arr)) + else: + tq = LocalTaskQueue(parallel=1) + tq.insert_all(MeshTaskIterator(chunks_arr)) \ No newline at end of file diff --git a/pychunkedgraph/meshing/meshing_sqs.py b/pychunkedgraph/meshing/meshing_sqs.py new file mode 100644 index 000000000..b302a1744 --- /dev/null +++ b/pychunkedgraph/meshing/meshing_sqs.py @@ -0,0 +1,29 @@ +from taskqueue import RegisteredTask +from pychunkedgraph.meshing import meshgen +import numpy as np + + +class MeshTask(RegisteredTask): + def __init__(self, cg_name, layer, chunk_id, mip, cache=True): + super().__init__(cg_name, layer, chunk_id, mip, cache) + + def execute(self): + cg_name = self.cg_name + chunk_id = np.uint64(self.chunk_id) + mip = self.mip + layer = self.layer + if layer == 2: + result = meshgen.chunk_initial_mesh_task( + cg_name, + chunk_id, + None, + mip=mip, + sharded=True, + cache=self.cache + ) + else: + result = meshgen.chunk_initial_sharded_stitching_task( + cg_name, chunk_id, mip, cache=self.cache + ) + print(result) + diff --git a/pychunkedgraph/meshing/meshing_test_temp.py b/pychunkedgraph/meshing/meshing_test_temp.py deleted file mode 100644 index 27665300e..000000000 --- a/pychunkedgraph/meshing/meshing_test_temp.py +++ /dev/null @@ -1,338 +0,0 @@ -import time -import click -import redis - -from flask import current_app -from flask.cli import AppGroup -from pychunkedgraph.backend.chunkedgraph import ChunkedGraph -from pychunkedgraph.meshing import meshgen -import cloudvolume -import numpy as np -from datetime import datetime - -ingest_cli = AppGroup('mesh') - -num_messages = 0 -messages = [] -cg = ChunkedGraph('fly_v31') -def handlerino_write_to_cloud(*args, 
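A sketch of queueing initial meshing tasks programmatically, mirroring what `meshing_batch.py` above does but without argparse or the shard-existence check. The graph name and chunk range are placeholders, and the `LocalTaskQueue` fallback is used here, as in the script.

```python
from taskqueue import LocalTaskQueue
from pychunkedgraph.graph.chunkedgraph import ChunkedGraph
from pychunkedgraph.meshing.meshing_sqs import MeshTask

cg = ChunkedGraph(graph_id="my_graph")                 # hypothetical graph id
layer, mip = 2, 2
tasks = []
for x in range(0, 2):                                  # hypothetical chunk range
    for y in range(0, 2):
        for z in range(0, 1):
            chunk_id = cg.get_chunk_id(layer=layer, x=x, y=y, z=z)
            tasks.append(MeshTask("my_graph", layer, int(chunk_id), mip, True))

tq = LocalTaskQueue(parallel=1)
tq.insert_all(tasks)
```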
**kwargs): - global num_messages - num_messages = num_messages + 1 - print(num_messages, args[0]['data']) - messages.append(args[0]['data']) - with open('output.txt', 'a') as f: - f.write(str(args[0]['data']) + '\n') - if num_messages == 1000: - print('DONE') - cv_path = cg._cv_path - with cloudvolume.Storage(cv_path) as storage: - storage.put_file( - file_path='frag_test/frag_test_summary_no_dust_threshold', - content=','.join(map(str, messages)), - compress=False, - cache_control='no-cache' - ) - - -def exc_handler(job, exc_type, exc_value, tb): - with open('exceptions.txt', 'a') as f: - f.write(str(job.args)+'\n') - f.write(str(exc_type) + '\n') - f.write(str(exc_value) + '\n') - - -def handlerino_print(*args, **kwargs): - with open('output.txt', 'a') as f: - f.write(str(args[0]['data']) + '\n') - - -def handlerino_periodically_write_to_cloud(*args, **kwargs): - global num_messages - num_messages = num_messages + 1 - print(num_messages, args[0]['data']) - messages.append(args[0]['data']) - with open('output.txt', 'a') as f: - f.write(str(args[0]['data']) + '\n') - if num_messages % 1000 == 0: - print('Writing result data to cloud') - cv_path = cg._cv_path - filename = f'{datetime.now()}_meshes_{num_messages}' - with cloudvolume.Storage(cv_path) as storage: - storage.put_file( - file_path=f'meshing_run_data/{filename}', - content=','.join(map(str, messages)), - compress=False, - cache_control='no-cache' - ) - - -@ingest_cli.command('mesh_chunks') -@click.argument('layer', type=int) -@click.argument('x_start', type=int) -@click.argument('y_start', type=int) -@click.argument('z_start', type=int) -@click.argument('x_end', type=int) -@click.argument('y_end', type=int) -@click.argument('z_end', type=int) -@click.argument('fragment_batch_size', type=int) -@click.argument('mesh_dir', type=str, default=None) -def mesh_chunks(layer, x_start, y_start, z_start, x_end, y_end, z_end, fragment_batch_size, mesh_dir): - print(f'Queueing...') - chunk_pubsub = current_app.redis.pubsub() - chunk_pubsub.subscribe(**{'mesh_frag_test_channel': handlerino_print}) - thread = chunk_pubsub.run_in_thread(sleep_time=0.1) - - cg = ChunkedGraph('fly_v31') - for x in range(x_start,x_end): - for y in range(y_start, y_end): - for z in range(z_start, z_end): - chunk_id = cg.get_chunk_id(None, layer, x, y, z) - current_app.test_q.enqueue( - meshgen.chunk_mesh_task_new_remapping, - job_timeout='20m', - args=( - cg.get_serialized_info(), - chunk_id, - cg._cv_path, - mesh_dir - ), - kwargs={ - # 'cv_mesh_dir': 'mesh_testing/initial_testrun_meshes', - 'mip': 1, - 'max_err': 320, - 'fragment_batch_size': fragment_batch_size - # 'dust_threshold': 100 - }) - - return 'Queued' - -@ingest_cli.command('mesh_chunks_shuffled') -@click.argument('layer', type=int) -@click.argument('x_start', type=int) -@click.argument('y_start', type=int) -@click.argument('z_start', type=int) -@click.argument('x_end', type=int) -@click.argument('y_end', type=int) -@click.argument('z_end', type=int) -@click.argument('fragment_batch_size', type=int, default=None) -@click.argument('mesh_dir', type=str, default=None) -def mesh_chunks_shuffled(layer, x_start, y_start, z_start, x_end, y_end, z_end, fragment_batch_size, mesh_dir): - print(f'Queueing...') - chunk_pubsub = current_app.redis.pubsub() - chunk_pubsub.subscribe(**{'mesh_frag_test_channel': handlerino_periodically_write_to_cloud}) - - cg = ChunkedGraph('fly_v31') - chunks_arr = [] - for x in range(x_start,x_end): - for y in range(y_start, y_end): - for z in range(z_start, z_end): - 
chunks_arr.append((x, y, z)) - - print(f'Total jobs: {len(chunks_arr)}') - - np.random.shuffle(chunks_arr) - - for chunk in chunks_arr: - chunk_id = cg.get_chunk_id(None, layer, chunk[0], chunk[1], chunk[2]) - current_app.test_q.enqueue( - meshgen.chunk_mesh_task_new_remapping, - job_timeout='300m', - args=( - cg.get_serialized_info(), - chunk_id, - cg._cv_path, - mesh_dir - ), - kwargs={ - # 'cv_mesh_dir': 'mesh_testing/initial_testrun_meshes', - 'mip': 1, - 'max_err': 320, - 'fragment_batch_size': fragment_batch_size - # 'dust_threshold': 100 - }) - - print(f'Queued jobs: {len(current_app.test_q)}') - thread = chunk_pubsub.run_in_thread(sleep_time=0.1) - - return 'Queued' - - -@ingest_cli.command('mesh_chunks_from_file') -@click.argument('filename', type=str) -def mesh_chunks_from_file(filename): - print(f'Queueing...') - chunk_pubsub = current_app.redis.pubsub() - chunk_pubsub.subscribe(**{'mesh_frag_test_channel': handlerino_periodically_write_to_cloud}) - thread = chunk_pubsub.run_in_thread(sleep_time=0.1) - cg = ChunkedGraph('fly_v31') - chunk_ids = [] - with open(filename, 'r') as f: - line = f.readline() - while line: - chunk_ids.append(np.uint64(line)) - line = f.readline() - for chunk_id in chunk_ids: - current_app.test_q.enqueue( - meshgen.chunk_mesh_task_new_remapping, - job_timeout='20m', - args=( - cg.get_serialized_info(), - chunk_id, - cg._cv_path - ), - kwargs={ - # 'cv_mesh_dir': 'mesh_testing/initial_testrun_meshes', - 'mip': 1, - 'max_err': 320 - # 'dust_threshold': 100 - }) - - return 'Queued' - - -@ingest_cli.command('mesh_chunks_exclude_file') -@click.argument('layer', type=int) -@click.argument('x_start', type=int) -@click.argument('y_start', type=int) -@click.argument('z_start', type=int) -@click.argument('x_end', type=int) -@click.argument('y_end', type=int) -@click.argument('z_end', type=int) -@click.argument('filename', type=str) -@click.argument('fragment_batch_size', type=int, default=None) -@click.argument('mesh_dir', type=str, default=None) -def mesh_chunks_exclude_file(layer, x_start, y_start, z_start, x_end, y_end, z_end, filename, fragment_batch_size, mesh_dir): - print(f'Queueing...') - chunk_pubsub = current_app.redis.pubsub() - chunk_pubsub.subscribe(**{'mesh_frag_test_channel': handlerino_periodically_write_to_cloud}) - cg = ChunkedGraph('fly_v31') - chunk_ids = [] - with open(filename, 'r') as f: - line = f.readline() - while line: - chunk_ids.append(np.uint64(line)) - line = f.readline() - - chunks_arr = [] - for x in range(x_start,x_end): - for y in range(y_start, y_end): - for z in range(z_start, z_end): - chunk_id = cg.get_chunk_id(None, layer, x, y, z) - if not chunk_id in chunk_ids: - chunks_arr.append(chunk_id) - - print(f'Total jobs: {len(chunks_arr)}') - np.random.shuffle(chunks_arr) - - for chunk_id in chunks_arr: - current_app.test_q.enqueue( - meshgen.chunk_mesh_task_new_remapping, - job_timeout='180m', - args=( - cg.get_serialized_info(), - chunk_id, - cg._cv_path, - mesh_dir - ), - kwargs={ - # 'cv_mesh_dir': 'mesh_testing/initial_testrun_meshes', - 'mip': 1, - 'max_err': 320, - 'fragment_batch_size': fragment_batch_size - # 'dust_threshold': 100 - }) - - print(f'Queued jobs: {len(current_app.test_q)}') - thread = chunk_pubsub.run_in_thread(sleep_time=0.1) - - return 'Queued' - - -@ingest_cli.command('mesh_chunk_ids_shuffled') -@click.argument('chunk_ids_string') -# chunk_ids_string = comma separated string list of chunk ids, e.g. 
"376263874141234936,513410357520258150" -def mesh_chunk_ids_shuffled(chunk_ids_string): - print(f'Queueing...') - chunk_pubsub = current_app.redis.pubsub() - chunk_pubsub.subscribe(**{'mesh_frag_test_channel': handlerino_periodically_write_to_cloud}) - thread = chunk_pubsub.run_in_thread(sleep_time=0.1) - - cg = ChunkedGraph('fly_v31') - - chunk_ids = np.uint64(chunk_ids_string.split(',')) - np.random.shuffle(chunk_ids) - - for chunk_id in chunk_ids: - current_app.test_q.enqueue( - meshgen.chunk_mesh_task_new_remapping, - job_timeout='20m', - args=( - cg.get_serialized_info(), - chunk_id, - cg._cv_path - ), - kwargs={ - # 'cv_mesh_dir': 'mesh_testing/initial_testrun_meshes', - 'mip': 1, - 'max_err': 320 - # 'dust_threshold': 100 - }) - - return 'Queued' - - -@ingest_cli.command('frag_test') -@click.argument('n', type=int) -@click.argument('layer', type=int) -def mesh_frag_test(n, layer): - print(f'Queueing...') - chunk_pubsub = current_app.redis.pubsub() - chunk_pubsub.subscribe(**{'mesh_frag_test_channel': handlerino_write_to_cloud}) - thread = chunk_pubsub.run_in_thread(sleep_time=0.1) - - cg = ChunkedGraph('fly_v31') - new_info = cg.cv.info - - dataset_size = np.array(new_info['scales'][0]['size']) - dim_in_chunks = np.ceil(dataset_size / new_info['graph']['chunk_size']) - num_chunks = np.prod(dim_in_chunks, dtype=np.int32) - rand_chunks = np.random.choice(num_chunks, n, replace=False) - for chunk in rand_chunks: - x_span = dim_in_chunks[1] * dim_in_chunks[2] - x = np.floor(chunk / x_span) - rem = chunk - (x * x_span) - y = np.floor(rem / dim_in_chunks[2]) - rem = rem - (y * dim_in_chunks[2]) - z = rem - chunk_id = cg.get_chunk_id(None, layer, np.int32(x), np.int32(y), np.int32(z)) - current_app.test_q.enqueue( - meshgen.chunk_mesh_task_new_remapping, - job_timeout='60m', - args=( - cg.get_serialized_info(), - chunk_id, - cg._cv_path - ), - kwargs={ - # 'cv_mesh_dir': 'mesh_testing/initial_testrun_meshes', - 'mip': 1, - 'max_err': 320, - 'return_frag_count': True - }) - - return 'Queued' - - -@ingest_cli.command('listen') -@click.argument('channel', default='mesh_frag_test_channel') -def mesh_listen(channel): - print(f'Queueing...') - chunk_pubsub = current_app.redis.pubsub() - chunk_pubsub.subscribe(**{channel: handlerino_periodically_write_to_cloud}) - thread = chunk_pubsub.run_in_thread(sleep_time=0.1) - - return 'Queued' - - -def init_mesh_cmds(app): - app.cli.add_command(ingest_cli) diff --git a/pychunkedgraph/rechunking/transformer.py b/pychunkedgraph/rechunking/transformer.py deleted file mode 100644 index 5f4ed3707..000000000 --- a/pychunkedgraph/rechunking/transformer.py +++ /dev/null @@ -1,209 +0,0 @@ -import itertools -import numpy as np -import glob -import os -import time - -import cloudvolume - -from pychunkedgraph.creator import creator_utils -from multiwrapper import multiprocessing_utils as mu - - -def _rewrite_segmentation_thread(args): - file_paths, from_url, to_url = args - - from_cv = cloudvolume.CloudVolume(from_url) - to_cv = cloudvolume.CloudVolume(to_url, bounded=False) - - assert 'svenmd' in to_url - - n_file_paths = len(file_paths) - - time_start = time.time() - for i_fp, fp in enumerate(file_paths): - if i_fp % 10 == 5: - dt = time.time() - time_start - eta = dt / i_fp * n_file_paths - dt - print("%d / %d - dt: %.3fs - eta: %.3fs" % - (i_fp, n_file_paths, dt, eta)) - - rewrite_single_segmentation_block(fp, from_cv=from_cv, to_cv=to_cv) - - -def rewrite_single_segmentation_block(file_path, from_cv=None, to_cv=None, - from_url=None, to_url=None): - if from_cv 
is None: - assert from_url is not None - from_cv = cloudvolume.CloudVolume(from_url) - - if to_cv is None: - assert to_url is not None - assert 'svenmd' in to_url - to_cv = cloudvolume.CloudVolume(to_url, bounded=False) - - dx, dy, dz, _ = os.path.basename(file_path).split("_") - - x_start, x_end = np.array(dx.split("-"), dtype=np.int) - y_start, y_end = np.array(dy.split("-"), dtype=np.int) - z_start, z_end = np.array(dz.split("-"), dtype=np.int) - - bbox = to_cv.bounds.to_list()[3:] - if x_end > bbox[0]: - x_end = bbox[0] - - if y_end > bbox[1]: - y_end = bbox[1] - - if z_end > bbox[2]: - z_end = bbox[2] - - seg = from_cv[x_start: x_end, y_start: y_end, z_start: z_end] - mapping = creator_utils.read_mapping_h5(file_path) - - if 0 in seg and not 0 in mapping[:, 0]: - mapping = np.concatenate(([np.array([[0, 0]], dtype=np.uint64), mapping])) - - sort_idx = np.argsort(mapping[:, 0]) - idx = np.searchsorted(mapping[:, 0], seg, sorter=sort_idx) - out = np.asarray(mapping[:, 1])[sort_idx][idx] - - # print(out.shape, x_start, x_end, y_start, y_end, z_start, z_end) - to_cv[x_start: x_end, y_start: y_end, z_start: z_end] = out - - -def rewrite_segmentation(dataset_name, n_threads=64, n_units_per_thread=None): - if dataset_name == "pinky": - cv_url = "gs://nkem/pinky40_v11/mst_trimmed_sem_remap/region_graph/" - from_url = "gs://neuroglancer/pinky40_v11/watershed/" - to_url = "gs://neuroglancer/svenmd/pinky40_v11/watershed/" - elif dataset_name == "basil": - cv_url = "gs://nkem/basil_4k_oldnet/region_graph/" - from_url = "gs://neuroglancer/ranl/basil_4k_oldnet/ws/" - to_url = "gs://neuroglancer/svenmd/basil_4k_oldnet_cg/watershed/" - else: - raise Exception("Dataset unknown") - - file_paths = np.sort(glob.glob(creator_utils.dir_from_layer_name( - creator_utils.layer_name_from_cv_url(cv_url)) + "/*rg2cg*")) - - if n_units_per_thread is None: - file_path_blocks = np.array_split(file_paths, n_threads*3) - else: - n_blocks = int(np.ceil(len(file_paths) / n_units_per_thread)) - file_path_blocks = np.array_split(file_paths, n_blocks) - - multi_args = [] - for fp_block in file_path_blocks: - multi_args.append([fp_block, from_url, to_url]) - - # Run parallelizing - if n_threads == 1: - mu.multiprocess_func(_rewrite_segmentation_thread, multi_args, - n_threads=n_threads, verbose=True, - debug=n_threads == 1) - else: - mu.multisubprocess_func(_rewrite_segmentation_thread, multi_args, - n_threads=n_threads) - - -def _rewrite_image_thread(args): - start_coordinates, end_coordinates, block_size, from_url, to_url, mip = args - - from_cv = cloudvolume.CloudVolume(from_url, mip=mip) - to_cv = cloudvolume.CloudVolume(to_url, bounded=False, mip=mip) - - assert 'svenmd' in to_url - - coordinate_iter = itertools.product(np.arange(start_coordinates[0], end_coordinates[0], block_size[0]), - np.arange(start_coordinates[1], end_coordinates[1], block_size[1]), - np.arange(start_coordinates[2], end_coordinates[2], block_size[2])) - - for coordinate in coordinate_iter: - rewrite_single_image_block(coordinate, block_size, from_cv=from_cv, - to_cv=to_cv) - - -def rewrite_single_image_block(coordinate, block_size, from_cv=None, to_cv=None, - from_url=None, to_url=None, mip=None): - if from_cv is None: - assert from_url is not None and mip is not None - from_cv = cloudvolume.CloudVolume(from_url, mip=mip) - - if to_cv is None: - assert to_url is not None and mip is not None - assert 'svenmd' in to_url - to_cv = cloudvolume.CloudVolume(to_url, bounded=False, mip=mip, - compress=False) - - x_start = coordinate[0] - x_end = 
coordinate[0] + block_size[0] - y_start = coordinate[1] - y_end = coordinate[1] + block_size[1] - z_start = coordinate[2] - z_end = coordinate[2] + block_size[2] - - bbox = to_cv.bounds.to_list()[3:] - if x_end > bbox[0]: - x_end = bbox[0] - - if y_end > bbox[1]: - y_end = bbox[1] - - if z_end > bbox[2]: - z_end = bbox[2] - - print(x_start, y_start, z_start, x_end, y_end, z_end) - - img = from_cv[x_start: x_end, y_start: y_end, z_start: z_end] - to_cv[x_start: x_end, y_start: y_end, z_start: z_end] = img - - -def rechunk_dataset(dataset_name, block_size=(1024, 1024, 64), n_threads=64, - mip=0): - if dataset_name == "pinky40em": - from_url = "gs://neuroglancer/pinky40_v11/image_rechunked/" - to_url = "gs://neuroglancer/svenmd/pinky40_v11/image_512_512_32/" - elif dataset_name == "pinky100seg": - from_url = "gs://neuroglancer/nkem/pinky100_v0/ws/lost_no-random/bbox1_0/" - to_url = "gs://neuroglancer/svenmd/pinky100_v0/ws/lost_no-random/bbox1_0_64_64_16/" - elif dataset_name == "basil": - raise() - else: - raise Exception("Dataset unknown") - - from_cv = cloudvolume.CloudVolume(from_url, mip=mip) - - dataset_bounds = np.array(from_cv.bounds.to_list()) - block_size = np.array(list(block_size)) - - super_block_size = block_size * 2 - - coordinate_iter = itertools.product(np.arange(dataset_bounds[0], - dataset_bounds[3], - super_block_size[0]), - np.arange(dataset_bounds[1], - dataset_bounds[4], - super_block_size[1]), - np.arange(dataset_bounds[2], - dataset_bounds[5], - super_block_size[2])) - coordinates = np.array(list(coordinate_iter)) - - multi_args = [] - for coordinate in coordinates: - end_coordinate = coordinate + super_block_size - m = end_coordinate > dataset_bounds[3:] - end_coordinate[m] = dataset_bounds[3:][m] - - multi_args.append([coordinate, end_coordinate, block_size, - from_url, to_url, mip]) - - # Run parallelizing - if n_threads == 1: - mu.multiprocess_func(_rewrite_image_thread, multi_args, - n_threads=n_threads, verbose=True, - debug=n_threads == 1) - else: - mu.multisubprocess_func(_rewrite_image_thread, multi_args, - n_threads=n_threads) \ No newline at end of file diff --git a/pychunkedgraph/repair/__init__.py b/pychunkedgraph/repair/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pychunkedgraph/repair/edits.py b/pychunkedgraph/repair/edits.py new file mode 100644 index 000000000..cb403a380 --- /dev/null +++ b/pychunkedgraph/repair/edits.py @@ -0,0 +1,63 @@ +# pylint: disable=protected-access,missing-function-docstring,invalid-name,wrong-import-position + +from datetime import timedelta + +from pychunkedgraph.graph import ChunkedGraph +from pychunkedgraph.graph.attributes import Concurrency +from pychunkedgraph.graph.operation import GraphEditOperation + + +def _get_previous_log_ts(cg, operation): + log, previous_ts = cg.client.read_log_entry(operation - 1) + if log: + return previous_ts + return _get_previous_log_ts(cg, operation - 1) + + +def repair_operation( + cg: ChunkedGraph, + operation_id: int, + unlock: bool = False, + use_preceding_edit_ts=True, +) -> GraphEditOperation.Result: + operation = GraphEditOperation.from_operation_id( + cg, operation_id, multicut_as_split=False, privileged_mode=True + ) + + _, current_ts = cg.client.read_log_entry(operation_id) + parent_ts = current_ts - timedelta(milliseconds=10) + if operation_id > 1 and use_preceding_edit_ts: + previous_ts = _get_previous_log_ts(cg, operation_id) + parent_ts = previous_ts + timedelta(milliseconds=100) + + result = operation.execute( + operation_id=operation_id, + 
parent_ts=parent_ts, + override_ts=current_ts + timedelta(milliseconds=1), + ) + old_roots = operation._update_root_ids() + + if unlock: + for root_ in old_roots: + cg.client.unlock_root(root_, result.operation_id) + cg.client.unlock_indefinitely_locked_root(root_, result.operation_id) + return result + + +if __name__ == "__main__": + op_ids_to_retry = [...] + locked_roots = [...] + + cg = ChunkedGraph(graph_id="") + node_attrs = cg.client.read_nodes(node_ids=locked_roots) + for node_id, attrs in node_attrs.items(): + if Concurrency.IndefiniteLock in attrs: + locked_op = attrs[Concurrency.IndefiniteLock][0].value + op_ids_to_retry.append(locked_op) + print(f"{node_id} indefinitely locked by op {locked_op}") + print(f"total to retry: {len(op_ids_to_retry)}") + + logs = cg.client.read_log_entries(op_ids_to_retry) + for op_id, log in logs.items(): + print(f"repairing {op_id}") + repair_operation(cg, log, op_id) diff --git a/pychunkedgraph/repair/fake_edges.py b/pychunkedgraph/repair/fake_edges.py new file mode 100644 index 000000000..b58b93fb9 --- /dev/null +++ b/pychunkedgraph/repair/fake_edges.py @@ -0,0 +1,78 @@ +# pylint: disable=protected-access,missing-function-docstring,invalid-name,wrong-import-position + +""" +Replay merge operations to check if fake edges need to be added. +""" + +from datetime import datetime +from datetime import timedelta +from os import environ +from typing import Optional + +environ["BIGTABLE_PROJECT"] = "<>" +environ["BIGTABLE_INSTANCE"] = "<>" +environ["GOOGLE_APPLICATION_CREDENTIALS"] = "" + +from pychunkedgraph.graph import edits +from pychunkedgraph.graph import ChunkedGraph +from pychunkedgraph.graph.operation import GraphEditOperation +from pychunkedgraph.graph.operation import MergeOperation +from pychunkedgraph.graph.utils.generic import get_bounding_box as get_bbox + + +def _add_fake_edges(cg: ChunkedGraph, operation_id: int, operation_log: dict) -> bool: + operation = GraphEditOperation.from_operation_id( + cg, operation_id, multicut_as_split=False + ) + + if not isinstance(operation, MergeOperation): + return False + + ts = operation_log["timestamp"] + parent_ts = ts - timedelta(seconds=0.1) + override_ts = (ts + timedelta(microseconds=(ts.microsecond % 1000) + 10),) + + root_ids = set( + cg.get_roots( + operation.added_edges.ravel(), assert_roots=True, time_stamp=parent_ts + ) + ) + + bbox = get_bbox( + operation.source_coords, operation.sink_coords, operation.bbox_offset + ) + edges = cg.get_subgraph( + root_ids, + bbox=bbox, + bbox_is_coordinate=True, + edges_only=True, + ) + + inactive_edges = edits.merge_preprocess( + cg, + subgraph_edges=edges, + supervoxels=operation.added_edges.ravel(), + parent_ts=parent_ts, + ) + + _, fake_edge_rows = edits.check_fake_edges( + cg, + atomic_edges=operation.added_edges, + inactive_edges=inactive_edges, + time_stamp=override_ts, + parent_ts=parent_ts, + ) + + cg.client.write(fake_edge_rows) + return len(fake_edge_rows) > 0 + + +def add_fake_edges( + graph_id: str, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, +): + cg = ChunkedGraph(graph_id=graph_id) + logs = cg.client.read_log_entries(start_time=start_time, end_time=end_time) + for _id, _log in logs.items(): + _add_fake_edges(cg, _id, _log) diff --git a/pychunkedgraph/tests/__init__.py b/pychunkedgraph/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pychunkedgraph/tests/data/sv_affinity.npy b/pychunkedgraph/tests/data/sv_affinity.npy new file mode 100644 index 000000000..59ab66708 Binary files 
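One thing to double-check in the `__main__` block of `repair/edits.py` above: the loop calls `repair_operation(cg, log, op_id)`, while the function defined earlier in the file takes `(cg, operation_id, unlock=..., use_preceding_edit_ts=...)`, so the log dict would be passed where the operation id is expected. A call consistent with that signature would look like the sketch below; the graph id and operation ids are placeholders.

```python
from pychunkedgraph.graph import ChunkedGraph
from pychunkedgraph.repair.edits import repair_operation

cg = ChunkedGraph(graph_id="my_graph")                 # hypothetical graph id
for op_id in (1234, 1235):                             # hypothetical operation ids
    result = repair_operation(cg, op_id, unlock=True)
    print(op_id, result.operation_id)
```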
/dev/null and b/pychunkedgraph/tests/data/sv_affinity.npy differ diff --git a/pychunkedgraph/tests/data/sv_area.npy b/pychunkedgraph/tests/data/sv_area.npy new file mode 100644 index 000000000..b14ac4772 Binary files /dev/null and b/pychunkedgraph/tests/data/sv_area.npy differ diff --git a/pychunkedgraph/tests/data/sv_edges.npy b/pychunkedgraph/tests/data/sv_edges.npy new file mode 100644 index 000000000..d46b8af0b Binary files /dev/null and b/pychunkedgraph/tests/data/sv_edges.npy differ diff --git a/pychunkedgraph/tests/data/sv_sinks.npy b/pychunkedgraph/tests/data/sv_sinks.npy new file mode 100644 index 000000000..bdae7f042 Binary files /dev/null and b/pychunkedgraph/tests/data/sv_sinks.npy differ diff --git a/pychunkedgraph/tests/data/sv_sources.npy b/pychunkedgraph/tests/data/sv_sources.npy new file mode 100644 index 000000000..5200b7f66 Binary files /dev/null and b/pychunkedgraph/tests/data/sv_sources.npy differ diff --git a/pychunkedgraph/tests/helpers.py b/pychunkedgraph/tests/helpers.py index 8097d9c77..de5314422 100644 --- a/pychunkedgraph/tests/helpers.py +++ b/pychunkedgraph/tests/helpers.py @@ -1,17 +1,25 @@ import os import subprocess -from datetime import timedelta -from functools import partial from math import inf -from signal import SIGTERM from time import sleep +from signal import SIGTERM +from functools import reduce +from functools import partial +from datetime import timedelta + -import numpy as np import pytest +import numpy as np from google.auth import credentials from google.cloud import bigtable -from pychunkedgraph.backend import chunkedgraph +from ..ingest.utils import bootstrap +from ..ingest.create.atomic_layer import add_atomic_edges +from ..graph.edges import Edges +from ..graph.edges import EDGE_TYPES +from ..graph.utils import basetypes +from ..graph.chunkedgraph import ChunkedGraph +from ..ingest.create.abstract_layers import add_layer class CloudVolumeBounds(object): @@ -31,7 +39,7 @@ def to_list(self): class CloudVolumeMock(object): def __init__(self): - self.resolution = np.array([1, 1, 1], dtype=np.int) + self.resolution = np.array([1, 1, 1], dtype=int) self.bounds = CloudVolumeBounds() @@ -39,7 +47,9 @@ def setup_emulator_env(): bt_env_init = subprocess.run( ["gcloud", "beta", "emulators", "bigtable", "env-init"], stdout=subprocess.PIPE ) - os.environ["BIGTABLE_EMULATOR_HOST"] = bt_env_init.stdout.decode("utf-8").strip().split("=")[-1] + os.environ["BIGTABLE_EMULATOR_HOST"] = ( + bt_env_init.stdout.decode("utf-8").strip().split("=")[-1] + ) c = bigtable.Client( project="IGNORE_ENVIRONMENT_PROJECT", @@ -60,7 +70,14 @@ def setup_emulator_env(): def bigtable_emulator(request): # Start Emulator bigtable_emulator = subprocess.Popen( - ["gcloud", "beta", "emulators", "bigtable", "start"], + [ + "gcloud", + "beta", + "emulators", + "bigtable", + "start", + "--host-port=localhost:8539", + ], preexec_fn=os.setsid, stdout=subprocess.PIPE, ) @@ -76,7 +93,9 @@ def bigtable_emulator(request): sleep(5) if retries == 0: - print("\nCouldn't start Bigtable Emulator. Make sure it is installed correctly.") + print( + "\nCouldn't start Bigtable Emulator. Make sure it is installed correctly." 
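A tiny worked example of the env-init parsing in `setup_emulator_env` above: `gcloud beta emulators bigtable env-init` prints a single export line, and the emulator host is everything after the `=` (the fixture starts the emulator on `localhost:8539`).

```python
line = "export BIGTABLE_EMULATOR_HOST=localhost:8539\n"
assert line.strip().split("=")[-1] == "localhost:8539"
```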
+ ) exit(1) # Setup Emulator-Finalizer @@ -87,46 +106,51 @@ def fin(): request.addfinalizer(fin) -@pytest.fixture(scope="function") -def lock_expired_timedelta_override(request): - # HACK: For the duration of the test, set global LOCK_EXPIRED_TIME_DELTA - # to 1 second (otherwise test would have to run for several minutes) - - original_timedelta = chunkedgraph.LOCK_EXPIRED_TIME_DELTA - - chunkedgraph.LOCK_EXPIRED_TIME_DELTA = timedelta(seconds=1) - - # Ensure that we restore the original value, even if the test fails. - def fin(): - chunkedgraph.LOCK_EXPIRED_TIME_DELTA = original_timedelta - - request.addfinalizer(fin) - return chunkedgraph.LOCK_EXPIRED_TIME_DELTA - - @pytest.fixture(scope="function") def gen_graph(request): - def _cgraph(request, fan_out=2, n_layers=10): - # setup Chunked Graph - dataset_info = {"data_dir": ""} - - graph = chunkedgraph.ChunkedGraph( - request.function.__name__, - project_id="IGNORE_ENVIRONMENT_PROJECT", - credentials=credentials.AnonymousCredentials(), - instance_id="emulated_instance", - dataset_info=dataset_info, - chunk_size=np.array([512, 512, 64], dtype=np.uint64), - is_new=True, - fan_out=np.uint64(fan_out), - n_layers=np.uint64(n_layers), + def _cgraph(request, n_layers=10, atomic_chunk_bounds: np.ndarray = np.array([])): + config = { + "data_source": { + "EDGES": "gs://chunked-graph/minnie65_0/edges", + "COMPONENTS": "gs://chunked-graph/minnie65_0/components", + "WATERSHED": "gs://microns-seunglab/minnie65/ws_minnie65_0", + }, + "graph_config": { + "CHUNK_SIZE": [512, 512, 64], + "FANOUT": 2, + "SPATIAL_BITS": 10, + "ID_PREFIX": "", + "ROOT_LOCK_EXPIRY": timedelta(seconds=5) + }, + "backend_client": { + "TYPE": "bigtable", + "CONFIG": { + "ADMIN": True, + "READ_ONLY": False, + "PROJECT": "IGNORE_ENVIRONMENT_PROJECT", + "INSTANCE": "emulated_instance", + "CREDENTIALS": credentials.AnonymousCredentials(), + "MAX_ROW_KEY_COUNT": 1000 + }, + }, + "ingest_config": {}, + } + + meta, _, client_info = bootstrap("test", config=config) + graph = ChunkedGraph(graph_id="test", meta=meta, + client_info=client_info) + graph.mock_edges = Edges([], []) + graph.meta._ws_cv = CloudVolumeMock() + graph.meta.layer_count = n_layers + graph.meta.layer_chunk_bounds = get_layer_chunk_bounds( + n_layers, atomic_chunk_bounds=atomic_chunk_bounds ) - graph._cv = CloudVolumeMock() + graph.create() # setup Chunked Graph - Finalizer def fin(): - graph.table.delete() + graph.client._table.delete() request.addfinalizer(fin) return graph @@ -152,7 +176,8 @@ def gen_graph_simplequerytest(request, gen_graph): # Chunk B create_chunk( graph, - vertices=[to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 1)], + vertices=[to_label(graph, 1, 1, 0, 0, 0), + to_label(graph, 1, 1, 0, 0, 1)], edges=[ (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 1), 0.5), (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 2, 0, 0, 0), inf), @@ -163,91 +188,115 @@ def gen_graph_simplequerytest(request, gen_graph): create_chunk( graph, vertices=[to_label(graph, 1, 2, 0, 0, 0)], - edges=[(to_label(graph, 1, 2, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf)], + edges=[(to_label(graph, 1, 2, 0, 0, 0), + to_label(graph, 1, 1, 0, 0, 0), inf)], ) - graph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), n_threads=1) - graph.add_layer(3, np.array([[2, 0, 0]]), n_threads=1) - graph.add_layer(4, np.array([[0, 0, 0], [1, 0, 0]]), n_threads=1) + add_layer(graph, 3, [0, 0, 0], n_threads=1) + add_layer(graph, 3, [1, 0, 0], n_threads=1) + add_layer(graph, 4, [0, 0, 0], n_threads=1) return graph -def 
create_chunk(cgraph, vertices=None, edges=None, timestamp=None): +def create_chunk(cg, vertices=None, edges=None, timestamp=None): """ Helper function to add vertices and edges to the chunkedgraph - no safety checks! """ - if not vertices: - vertices = [] - - if not edges: - edges = [] - + edges = edges if edges else [] + vertices = vertices if vertices else [] vertices = np.unique(np.array(vertices, dtype=np.uint64)) - edges = [(np.uint64(v1), np.uint64(v2), np.float32(aff)) for v1, v2, aff in edges] - - isolated_node_ids = [ + edges = [(np.uint64(v1), np.uint64(v2), np.float32(aff)) + for v1, v2, aff in edges] + isolated_ids = [ x for x in vertices if (x not in [edges[i][0] for i in range(len(edges))]) and (x not in [edges[i][1] for i in range(len(edges))]) ] - edge_ids = { - "in_connected": np.array([], dtype=np.uint64).reshape(0, 2), - "in_disconnected": np.array([], dtype=np.uint64).reshape(0, 2), - "cross": np.array([], dtype=np.uint64).reshape(0, 2), - "between_connected": np.array([], dtype=np.uint64).reshape(0, 2), - "between_disconnected": np.array([], dtype=np.uint64).reshape(0, 2), - } - edge_affs = { - "in_connected": np.array([], dtype=np.float32), - "in_disconnected": np.array([], dtype=np.float32), - "between_connected": np.array([], dtype=np.float32), - "between_disconnected": np.array([], dtype=np.float32), - } + chunk_edges_active = {} + for edge_type in EDGE_TYPES: + chunk_edges_active[edge_type] = Edges([], []) for e in edges: - if cgraph.test_if_nodes_are_in_same_chunk(e[0:2]): - this_edge = np.array([e[0], e[1]], dtype=np.uint64).reshape(-1, 2) - edge_ids["in_connected"] = np.concatenate([edge_ids["in_connected"], this_edge]) - edge_affs["in_connected"] = np.concatenate([edge_affs["in_connected"], [e[2]]]) - - if len(edge_ids["in_connected"]) > 0: - chunk_id = cgraph.get_chunk_id(edge_ids["in_connected"][0][0]) - elif len(vertices) > 0: - chunk_id = cgraph.get_chunk_id(vertices[0]) - else: - chunk_id = None + if cg.get_chunk_id(e[0]) == cg.get_chunk_id(e[1]): + sv1s = np.array([e[0]], dtype=basetypes.NODE_ID) + sv2s = np.array([e[1]], dtype=basetypes.NODE_ID) + affs = np.array([e[2]], dtype=basetypes.EDGE_AFFINITY) + chunk_edges_active[EDGE_TYPES.in_chunk] += Edges( + sv1s, sv2s, affinities=affs + ) + + chunk_id = None + if len(chunk_edges_active[EDGE_TYPES.in_chunk]): + chunk_id = cg.get_chunk_id( + chunk_edges_active[EDGE_TYPES.in_chunk].node_ids1[0]) + elif len(vertices): + chunk_id = cg.get_chunk_id(vertices[0]) for e in edges: - if not cgraph.test_if_nodes_are_in_same_chunk(e[0:2]): + if not cg.get_chunk_id(e[0]) == cg.get_chunk_id(e[1]): # Ensure proper order if chunk_id is not None: - if cgraph.get_chunk_id(e[0]) != chunk_id: + if not chunk_id == cg.get_chunk_id(e[0]): e = [e[1], e[0], e[2]] - this_edge = np.array([e[0], e[1]], dtype=np.uint64).reshape(-1, 2) - + sv1s = np.array([e[0]], dtype=basetypes.NODE_ID) + sv2s = np.array([e[1]], dtype=basetypes.NODE_ID) + affs = np.array([e[2]], dtype=basetypes.EDGE_AFFINITY) if np.isinf(e[2]): - edge_ids["cross"] = np.concatenate([edge_ids["cross"], this_edge]) - else: - edge_ids["between_connected"] = np.concatenate( - [edge_ids["between_connected"], this_edge] + chunk_edges_active[EDGE_TYPES.cross_chunk] += Edges( + sv1s, sv2s, affinities=affs ) - edge_affs["between_connected"] = np.concatenate( - [edge_affs["between_connected"], [e[2]]] + else: + chunk_edges_active[EDGE_TYPES.between_chunk] += Edges( + sv1s, sv2s, affinities=affs ) - isolated_node_ids = np.array(isolated_node_ids, dtype=np.uint64) + all_edges = 
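A pure-Python worked example of the edge classification in the rewritten `create_chunk` helper above: edges whose endpoints share a chunk are in-chunk, cross-chunk edges are flagged by an infinite affinity, and everything else is between-chunk. The chunk-id function is a toy stand-in for `cg.get_chunk_id`.

```python
import numpy as np

def chunk_of(node_id):
    return node_id // 100        # toy stand-in for cg.get_chunk_id

edges = [(101, 102, 0.5), (101, 201, np.inf), (102, 203, 0.3)]
in_chunk, cross_chunk, between_chunk = [], [], []
for v1, v2, aff in edges:
    if chunk_of(v1) == chunk_of(v2):
        in_chunk.append((v1, v2, aff))
    elif np.isinf(aff):
        cross_chunk.append((v1, v2, aff))
    else:
        between_chunk.append((v1, v2, aff))

assert in_chunk == [(101, 102, 0.5)]
assert cross_chunk == [(101, 201, np.inf)]
assert between_chunk == [(102, 203, 0.3)]
```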
reduce(lambda x, y: x + y, chunk_edges_active.values()) + cg.mock_edges += all_edges - cgraph.logger.debug(edge_ids) - cgraph.logger.debug(edge_affs) - - # Use affinities as areas - cgraph.add_atomic_edges_in_chunks( - edge_ids, edge_affs, edge_affs, isolated_node_ids, time_stamp=timestamp + isolated_ids = np.array(isolated_ids, dtype=np.uint64) + add_atomic_edges( + cg, + cg.get_chunk_coordinates(chunk_id), + chunk_edges_active, + isolated=isolated_ids, ) -def to_label(cgraph, l, x, y, z, segment_id): - return cgraph.get_node_id(np.uint64(segment_id), layer=l, x=x, y=y, z=z) +def to_label(cg, l, x, y, z, segment_id): + return cg.get_node_id(np.uint64(segment_id), layer=l, x=x, y=y, z=z) + + +def get_layer_chunk_bounds( + n_layers: int, atomic_chunk_bounds: np.ndarray = np.array([]) +) -> dict: + if atomic_chunk_bounds.size == 0: + limit = 2 ** (n_layers - 2) + atomic_chunk_bounds = np.array([limit, limit, limit]) + layer_bounds_d = {} + for layer in range(2, n_layers): + layer_bounds = atomic_chunk_bounds / (2 ** (layer - 2)) + layer_bounds_d[layer] = np.ceil(layer_bounds).astype(int) + return layer_bounds_d + + +@pytest.fixture(scope='session') +def sv_data(): + test_data_dir = 'pychunkedgraph/tests/data' + edges_file = f'{test_data_dir}/sv_edges.npy' + sv_edges = np.load(edges_file) + + source_file = f'{test_data_dir}/sv_sources.npy' + sv_sources = np.load(source_file) + + sinks_file = f'{test_data_dir}/sv_sinks.npy' + sv_sinks = np.load(sinks_file) + + affinity_file = f'{test_data_dir}/sv_affinity.npy' + sv_affinity = np.load(affinity_file) + + area_file = f'{test_data_dir}/sv_area.npy' + sv_area = np.load(area_file) + yield (sv_edges, sv_sources, sv_sinks, sv_affinity, sv_area) diff --git a/pychunkedgraph/tests/test_graphoperation.py b/pychunkedgraph/tests/test_graphoperation.py deleted file mode 100644 index 3ea902727..000000000 --- a/pychunkedgraph/tests/test_graphoperation.py +++ /dev/null @@ -1,246 +0,0 @@ -from collections import namedtuple - -import numpy as np -import pytest - -from pychunkedgraph.backend.graphoperation import ( - GraphEditOperation, - MergeOperation, - MulticutOperation, - RedoOperation, - SplitOperation, - UndoOperation, -) -from pychunkedgraph.backend.utils import column_keys - - -class FakeLogRecords: - Record = namedtuple("graph_op", ("id", "record")) - - _records = [ - { # 0: Merge with coordinates - column_keys.OperationLogs.AddedEdge: np.array([[1, 2]], dtype=np.uint64), - column_keys.OperationLogs.SinkCoordinate: np.array([[1, 2, 3]]), - column_keys.OperationLogs.SourceCoordinate: np.array([[4, 5, 6]]), - column_keys.OperationLogs.UserID: "42", - }, - { # 1: Multicut with coordinates - column_keys.OperationLogs.BoundingBoxOffset: np.array([240, 240, 24]), - column_keys.OperationLogs.RemovedEdge: np.array( - [[1, 3], [4, 1], [1, 5]], dtype=np.uint64 - ), - column_keys.OperationLogs.SinkCoordinate: np.array([[1, 2, 3]]), - column_keys.OperationLogs.SinkID: np.array([1], dtype=np.uint64), - column_keys.OperationLogs.SourceCoordinate: np.array([[4, 5, 6]]), - column_keys.OperationLogs.SourceID: np.array([2], dtype=np.uint64), - column_keys.OperationLogs.UserID: "42", - }, - { # 2: Split with coordinates - column_keys.OperationLogs.RemovedEdge: np.array( - [[1, 3], [4, 1], [1, 5]], dtype=np.uint64 - ), - column_keys.OperationLogs.SinkCoordinate: np.array([[1, 2, 3]]), - column_keys.OperationLogs.SinkID: np.array([1], dtype=np.uint64), - column_keys.OperationLogs.SourceCoordinate: np.array([[4, 5, 6]]), - column_keys.OperationLogs.SourceID: np.array([2], 
dtype=np.uint64), - column_keys.OperationLogs.UserID: "42", - }, - { # 3: Undo of records[0] - column_keys.OperationLogs.UndoOperationID: np.uint64(0), - column_keys.OperationLogs.UserID: "42", - }, - { # 4: Redo of records[0] - column_keys.OperationLogs.RedoOperationID: np.uint64(0), - column_keys.OperationLogs.UserID: "42", - }, - { # 5: Unknown record - column_keys.OperationLogs.UserID: "42", - }, - ] - - MERGE = Record(id=np.uint64(0), record=_records[0]) - MULTICUT = Record(id=np.uint64(1), record=_records[1]) - SPLIT = Record(id=np.uint64(2), record=_records[2]) - UNDO = Record(id=np.uint64(3), record=_records[3]) - REDO = Record(id=np.uint64(4), record=_records[4]) - UNKNOWN = Record(id=np.uint64(5), record=_records[5]) - - @classmethod - def get(cls, idx: int): - try: - return cls._records[idx] - except IndexError as err: - raise KeyError(err) # Bigtable would throw KeyError instead - - -@pytest.fixture(scope="function") -def cg(mocker): - graph = mocker.MagicMock() - graph.get_chunk_layer = mocker.MagicMock(return_value=1) - graph.read_log_row = mocker.MagicMock(side_effect=FakeLogRecords.get) - return graph - - -def test_read_from_log_merge(mocker, cg): - """MergeOperation should be correctly identified by an existing AddedEdge column. - Coordinates are optional.""" - graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.MERGE.record) - assert isinstance(graph_operation, MergeOperation) - - -def test_read_from_log_multicut(mocker, cg): - """MulticutOperation should be correctly identified by a Sink/Source ID and - BoundingBoxOffset column. Unless requested as SplitOperation...""" - graph_operation = GraphEditOperation.from_log_record( - cg, FakeLogRecords.MULTICUT.record, multicut_as_split=False - ) - assert isinstance(graph_operation, MulticutOperation) - - graph_operation = GraphEditOperation.from_log_record( - cg, FakeLogRecords.MULTICUT.record, multicut_as_split=True - ) - assert isinstance(graph_operation, SplitOperation) - - -def test_read_from_log_split(mocker, cg): - """SplitOperation should be correctly identified by the lack of a - BoundingBoxOffset column.""" - graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.SPLIT.record) - assert isinstance(graph_operation, SplitOperation) - - -def test_read_from_log_undo(mocker, cg): - """UndoOperation should be correctly identified by the UndoOperationID.""" - graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.UNDO.record) - assert isinstance(graph_operation, UndoOperation) - - -def test_read_from_log_redo(mocker, cg): - """RedoOperation should be correctly identified by the RedoOperationID.""" - graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.REDO.record) - assert isinstance(graph_operation, RedoOperation) - - -def test_read_from_log_undo_undo(mocker, cg): - """Undo[Undo[Merge]] -> Redo[Merge]""" - fake_log_record = { - column_keys.OperationLogs.UndoOperationID: np.uint64(FakeLogRecords.UNDO.id), - column_keys.OperationLogs.UserID: "42", - } - - graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) - assert isinstance(graph_operation, RedoOperation) - assert isinstance(graph_operation.superseded_operation, MergeOperation) - - -def test_read_from_log_undo_redo(mocker, cg): - """Undo[Redo[Merge]] -> Undo[Merge]""" - fake_log_record = { - column_keys.OperationLogs.UndoOperationID: np.uint64(FakeLogRecords.REDO.id), - column_keys.OperationLogs.UserID: "42", - } - - graph_operation = GraphEditOperation.from_log_record(cg, 
fake_log_record) - assert isinstance(graph_operation, UndoOperation) - assert isinstance(graph_operation.inverse_superseded_operation, SplitOperation) - - -def test_read_from_log_redo_undo(mocker, cg): - """Redo[Undo[Merge]] -> Undo[Merge]""" - fake_log_record = { - column_keys.OperationLogs.RedoOperationID: np.uint64(FakeLogRecords.UNDO.id), - column_keys.OperationLogs.UserID: "42", - } - - graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) - assert isinstance(graph_operation, UndoOperation) - assert isinstance(graph_operation.inverse_superseded_operation, SplitOperation) - - -def test_read_from_log_redo_redo(mocker, cg): - """Redo[Redo[Merge]] -> Redo[Merge]""" - fake_log_record = { - column_keys.OperationLogs.RedoOperationID: np.uint64(FakeLogRecords.REDO.id), - column_keys.OperationLogs.UserID: "42", - } - - graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) - assert isinstance(graph_operation, RedoOperation) - assert isinstance(graph_operation.superseded_operation, MergeOperation) - - -def test_invert_merge(mocker, cg): - """Inverse of Merge is a Split""" - graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.MERGE.record) - inverted_graph_operation = graph_operation.invert() - assert isinstance(inverted_graph_operation, SplitOperation) - assert np.all(np.equal(graph_operation.added_edges, inverted_graph_operation.removed_edges)) - - -@pytest.mark.skip(reason="Can't test right now - would require recalculting the Multicut") -def test_invert_multicut(mocker, cg): - """Inverse of a Multicut is a Merge""" - - -def test_invert_split(mocker, cg): - """Inverse of Split is a Merge""" - graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.SPLIT.record) - inverted_graph_operation = graph_operation.invert() - assert isinstance(inverted_graph_operation, MergeOperation) - assert np.all(np.equal(graph_operation.removed_edges, inverted_graph_operation.added_edges)) - - -def test_invert_undo(mocker, cg): - """Inverse of Undo[x] is Redo[x]""" - graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.UNDO.record) - inverted_graph_operation = graph_operation.invert() - assert isinstance(inverted_graph_operation, RedoOperation) - assert ( - graph_operation.superseded_operation_id == inverted_graph_operation.superseded_operation_id - ) - - -def test_invert_redo(mocker, cg): - """Inverse of Redo[x] is Undo[x]""" - graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.REDO.record) - inverted_graph_operation = graph_operation.invert() - assert ( - graph_operation.superseded_operation_id == inverted_graph_operation.superseded_operation_id - ) - - -def test_undo_redo_chain_fails(mocker, cg): - """Prevent creation of Undo/Redo chains""" - with pytest.raises(ValueError): - UndoOperation( - cg, - user_id="DAU", - superseded_operation_id=FakeLogRecords.UNDO.id, - multicut_as_split=False, - ) - with pytest.raises(ValueError): - UndoOperation( - cg, - user_id="DAU", - superseded_operation_id=FakeLogRecords.REDO.id, - multicut_as_split=False, - ) - with pytest.raises(ValueError): - RedoOperation( - cg, - user_id="DAU", - superseded_operation_id=FakeLogRecords.UNDO.id, - multicut_as_split=False, - ) - with pytest.raises(ValueError): - UndoOperation( - cg, - user_id="DAU", - superseded_operation_id=FakeLogRecords.REDO.id, - multicut_as_split=False, - ) - -def test_unknown_log_record_fails(cg, mocker): - """TypeError when encountering unknown log row""" - with pytest.raises(TypeError): - 
GraphEditOperation.from_log_record(cg, FakeLogRecords.UNKNOWN.record) diff --git a/pychunkedgraph/tests/test_operation.py b/pychunkedgraph/tests/test_operation.py new file mode 100644 index 000000000..ff7cb65bd --- /dev/null +++ b/pychunkedgraph/tests/test_operation.py @@ -0,0 +1,261 @@ +# from collections import namedtuple + +# import numpy as np +# import pytest + +# from ..graph.operation import ( +# GraphEditOperation, +# MergeOperation, +# MulticutOperation, +# RedoOperation, +# SplitOperation, +# UndoOperation, +# ) +# from ..graph import attributes + + +# class FakeLogRecords: +# Record = namedtuple("graph_op", ("id", "record")) + +# _records = [ +# { # 0: Merge with coordinates +# attributes.OperationLogs.AddedEdge: np.array([[1, 2]], dtype=np.uint64), +# attributes.OperationLogs.SinkCoordinate: np.array([[1, 2, 3]]), +# attributes.OperationLogs.SourceCoordinate: np.array([[4, 5, 6]]), +# attributes.OperationLogs.UserID: "42", +# }, +# { # 1: Multicut with coordinates +# attributes.OperationLogs.BoundingBoxOffset: np.array([240, 240, 24]), +# attributes.OperationLogs.RemovedEdge: np.array( +# [[1, 3], [4, 1], [1, 5]], dtype=np.uint64 +# ), +# attributes.OperationLogs.SinkCoordinate: np.array([[1, 2, 3]]), +# attributes.OperationLogs.SinkID: np.array([1], dtype=np.uint64), +# attributes.OperationLogs.SourceCoordinate: np.array([[4, 5, 6]]), +# attributes.OperationLogs.SourceID: np.array([2], dtype=np.uint64), +# attributes.OperationLogs.UserID: "42", +# }, +# { # 2: Split with coordinates +# attributes.OperationLogs.RemovedEdge: np.array( +# [[1, 3], [4, 1], [1, 5]], dtype=np.uint64 +# ), +# attributes.OperationLogs.SinkCoordinate: np.array([[1, 2, 3]]), +# attributes.OperationLogs.SinkID: np.array([1], dtype=np.uint64), +# attributes.OperationLogs.SourceCoordinate: np.array([[4, 5, 6]]), +# attributes.OperationLogs.SourceID: np.array([2], dtype=np.uint64), +# attributes.OperationLogs.UserID: "42", +# }, +# { # 3: Undo of records[0] +# attributes.OperationLogs.UndoOperationID: np.uint64(0), +# attributes.OperationLogs.UserID: "42", +# }, +# { # 4: Redo of records[0] +# attributes.OperationLogs.RedoOperationID: np.uint64(0), +# attributes.OperationLogs.UserID: "42", +# }, +# {attributes.OperationLogs.UserID: "42",}, # 5: Unknown record +# ] + +# MERGE = Record(id=np.uint64(0), record=_records[0]) +# MULTICUT = Record(id=np.uint64(1), record=_records[1]) +# SPLIT = Record(id=np.uint64(2), record=_records[2]) +# UNDO = Record(id=np.uint64(3), record=_records[3]) +# REDO = Record(id=np.uint64(4), record=_records[4]) +# UNKNOWN = Record(id=np.uint64(5), record=_records[5]) + +# @classmethod +# def get(cls, idx: int): +# try: +# return cls._records[idx] +# except IndexError as err: +# raise KeyError(err) # Bigtable would throw KeyError instead + + +# @pytest.fixture(scope="function") +# def cg(mocker): +# graph = mocker.MagicMock() +# graph.get_chunk_layer = mocker.MagicMock(return_value=1) +# graph.read_log_row = mocker.MagicMock(side_effect=FakeLogRecords.get) +# return graph + + +# def test_read_from_log_merge(mocker, cg): +# """MergeOperation should be correctly identified by an existing AddedEdge column. +# Coordinates are optional.""" +# graph_operation = GraphEditOperation.from_log_record( +# cg, FakeLogRecords.MERGE.record +# ) +# assert isinstance(graph_operation, MergeOperation) + + +# def test_read_from_log_multicut(mocker, cg): +# """MulticutOperation should be correctly identified by a Sink/Source ID and +# BoundingBoxOffset column. 
Unless requested as SplitOperation...""" +# graph_operation = GraphEditOperation.from_log_record( +# cg, FakeLogRecords.MULTICUT.record, multicut_as_split=False +# ) +# assert isinstance(graph_operation, MulticutOperation) + +# graph_operation = GraphEditOperation.from_log_record( +# cg, FakeLogRecords.MULTICUT.record, multicut_as_split=True +# ) +# assert isinstance(graph_operation, SplitOperation) + + +# def test_read_from_log_split(mocker, cg): +# """SplitOperation should be correctly identified by the lack of a +# BoundingBoxOffset column.""" +# graph_operation = GraphEditOperation.from_log_record( +# cg, FakeLogRecords.SPLIT.record +# ) +# assert isinstance(graph_operation, SplitOperation) + + +# def test_read_from_log_undo(mocker, cg): +# """UndoOperation should be correctly identified by the UndoOperationID.""" +# graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.UNDO.record) +# assert isinstance(graph_operation, UndoOperation) + + +# def test_read_from_log_redo(mocker, cg): +# """RedoOperation should be correctly identified by the RedoOperationID.""" +# graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.REDO.record) +# assert isinstance(graph_operation, RedoOperation) + + +# def test_read_from_log_undo_undo(mocker, cg): +# """Undo[Undo[Merge]] -> Redo[Merge]""" +# fake_log_record = { +# attributes.OperationLogs.UndoOperationID: np.uint64(FakeLogRecords.UNDO.id), +# attributes.OperationLogs.UserID: "42", +# } + +# graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) +# assert isinstance(graph_operation, RedoOperation) +# assert isinstance(graph_operation.superseded_operation, MergeOperation) + + +# def test_read_from_log_undo_redo(mocker, cg): +# """Undo[Redo[Merge]] -> Undo[Merge]""" +# fake_log_record = { +# attributes.OperationLogs.UndoOperationID: np.uint64(FakeLogRecords.REDO.id), +# attributes.OperationLogs.UserID: "42", +# } + +# graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) +# assert isinstance(graph_operation, UndoOperation) +# assert isinstance(graph_operation.inverse_superseded_operation, SplitOperation) + + +# def test_read_from_log_redo_undo(mocker, cg): +# """Redo[Undo[Merge]] -> Undo[Merge]""" +# fake_log_record = { +# attributes.OperationLogs.RedoOperationID: np.uint64(FakeLogRecords.UNDO.id), +# attributes.OperationLogs.UserID: "42", +# } + +# graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) +# assert isinstance(graph_operation, UndoOperation) +# assert isinstance(graph_operation.inverse_superseded_operation, SplitOperation) + + +# def test_read_from_log_redo_redo(mocker, cg): +# """Redo[Redo[Merge]] -> Redo[Merge]""" +# fake_log_record = { +# attributes.OperationLogs.RedoOperationID: np.uint64(FakeLogRecords.REDO.id), +# attributes.OperationLogs.UserID: "42", +# } + +# graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) +# assert isinstance(graph_operation, RedoOperation) +# assert isinstance(graph_operation.superseded_operation, MergeOperation) + + +# def test_invert_merge(mocker, cg): +# """Inverse of Merge is a Split""" +# graph_operation = GraphEditOperation.from_log_record( +# cg, FakeLogRecords.MERGE.record +# ) +# inverted_graph_operation = graph_operation.invert() +# assert isinstance(inverted_graph_operation, SplitOperation) +# assert np.all( +# np.equal(graph_operation.added_edges, inverted_graph_operation.removed_edges) +# ) + + +# @pytest.mark.skip( +# reason="Can't test right now - would require recalculting the 
Multicut" +# ) +# def test_invert_multicut(mocker, cg): +# """Inverse of a Multicut is a Merge""" + + +# def test_invert_split(mocker, cg): +# """Inverse of Split is a Merge""" +# graph_operation = GraphEditOperation.from_log_record( +# cg, FakeLogRecords.SPLIT.record +# ) +# inverted_graph_operation = graph_operation.invert() +# assert isinstance(inverted_graph_operation, MergeOperation) +# assert np.all( +# np.equal(graph_operation.removed_edges, inverted_graph_operation.added_edges) +# ) + + +# def test_invert_undo(mocker, cg): +# """Inverse of Undo[x] is Redo[x]""" +# graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.UNDO.record) +# inverted_graph_operation = graph_operation.invert() +# assert isinstance(inverted_graph_operation, RedoOperation) +# assert ( +# graph_operation.superseded_operation_id +# == inverted_graph_operation.superseded_operation_id +# ) + + +# def test_invert_redo(mocker, cg): +# """Inverse of Redo[x] is Undo[x]""" +# graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.REDO.record) +# inverted_graph_operation = graph_operation.invert() +# assert ( +# graph_operation.superseded_operation_id +# == inverted_graph_operation.superseded_operation_id +# ) + + +# def test_undo_redo_chain_fails(mocker, cg): +# """Prevent creation of Undo/Redo chains""" +# with pytest.raises(ValueError): +# UndoOperation( +# cg, +# user_id="DAU", +# superseded_operation_id=FakeLogRecords.UNDO.id, +# multicut_as_split=False, +# ) +# with pytest.raises(ValueError): +# UndoOperation( +# cg, +# user_id="DAU", +# superseded_operation_id=FakeLogRecords.REDO.id, +# multicut_as_split=False, +# ) +# with pytest.raises(ValueError): +# RedoOperation( +# cg, +# user_id="DAU", +# superseded_operation_id=FakeLogRecords.UNDO.id, +# multicut_as_split=False, +# ) +# with pytest.raises(ValueError): +# UndoOperation( +# cg, +# user_id="DAU", +# superseded_operation_id=FakeLogRecords.REDO.id, +# multicut_as_split=False, +# ) + + +# def test_unknown_log_record_fails(cg, mocker): +# """TypeError when encountering unknown log row""" +# with pytest.raises(TypeError): +# GraphEditOperation.from_log_record(cg, FakeLogRecords.UNKNOWN.record) diff --git a/pychunkedgraph/tests/test_root_lock.py b/pychunkedgraph/tests/test_root_lock.py index 5ed0862bd..a5ef7d4d2 100644 --- a/pychunkedgraph/tests/test_root_lock.py +++ b/pychunkedgraph/tests/test_root_lock.py @@ -1,98 +1,104 @@ -from unittest.mock import DEFAULT - -import numpy as np -import pytest - -import pychunkedgraph.backend.chunkedgraph_exceptions as cg_exceptions -from pychunkedgraph.backend.root_lock import RootLock - -G_UINT64 = np.uint64(2 ** 63) - - -def big_uint64(): - """Return incremental uint64 values larger than a signed int64""" - global G_UINT64 - if G_UINT64 == np.uint64(2 ** 64 - 1): - G_UINT64 = np.uint64(2 ** 63) - G_UINT64 = G_UINT64 + np.uint64(1) - return G_UINT64 - - -class RootLockTracker: - def __init__(self): - self.active_locks = dict() - - def add_locks(self, root_ids, operation_id, **kwargs): - if operation_id not in self.active_locks: - self.active_locks[operation_id] = set(root_ids) - else: - self.active_locks[operation_id].update(root_ids) - return DEFAULT - - def remove_lock(self, root_id, operation_id, **kwargs): - if operation_id in self.active_locks: - self.active_locks[operation_id].discard(root_id) - return DEFAULT - - -@pytest.fixture() -def root_lock_tracker(): - return RootLockTracker() - - -def test_successful_lock_acquisition(mocker, root_lock_tracker): - """Ensure that root locks got 
released after successful - root lock acquisition + *successful* graph operation""" - fake_operation_id = big_uint64() - fake_locked_root_ids = np.array((big_uint64(), big_uint64())) - - cg = mocker.MagicMock() - cg.get_unique_operation_id = mocker.MagicMock(return_value=fake_operation_id) - cg.lock_root_loop = mocker.MagicMock( - return_value=(True, fake_locked_root_ids), side_effect=root_lock_tracker.add_locks - ) - cg.unlock_root = mocker.MagicMock(return_value=True, side_effect=root_lock_tracker.remove_lock) - - with RootLock(cg, fake_locked_root_ids): - assert fake_operation_id in root_lock_tracker.active_locks - assert not root_lock_tracker.active_locks[fake_operation_id].difference( - fake_locked_root_ids - ) - - assert not root_lock_tracker.active_locks[fake_operation_id] - - -def test_failed_lock_acquisition(mocker): - """Ensure that LockingError is raised when lock acquisition failed""" - fake_operation_id = big_uint64() - fake_locked_root_ids = np.array((big_uint64(), big_uint64())) - - cg = mocker.MagicMock() - cg.get_unique_operation_id = mocker.MagicMock(return_value=fake_operation_id) - cg.lock_root_loop = mocker.MagicMock( - return_value=(False, fake_locked_root_ids), side_effect=None - ) - - with pytest.raises(cg_exceptions.LockingError): - with RootLock(cg, fake_locked_root_ids): - pass - - -def test_failed_graph_operation(mocker, root_lock_tracker): - """Ensure that root locks got released after successful - root lock acquisition + *unsuccessful* graph operation""" - fake_operation_id = big_uint64() - fake_locked_root_ids = np.array((big_uint64(), big_uint64())) - - cg = mocker.MagicMock() - cg.get_unique_operation_id = mocker.MagicMock(return_value=fake_operation_id) - cg.lock_root_loop = mocker.MagicMock( - return_value=(True, fake_locked_root_ids), side_effect=root_lock_tracker.add_locks - ) - cg.unlock_root = mocker.MagicMock(return_value=True, side_effect=root_lock_tracker.remove_lock) - - with pytest.raises(cg_exceptions.PreconditionError): - with RootLock(cg, fake_locked_root_ids): - raise cg_exceptions.PreconditionError("Something went wrong") - - assert not root_lock_tracker.active_locks[fake_operation_id] +# from unittest.mock import DEFAULT + +# import numpy as np +# import pytest + +# from ..graph import exceptions +# from ..graph.locks import RootLock + +# G_UINT64 = np.uint64(2 ** 63) + + +# def big_uint64(): +# """Return incremental uint64 values larger than a signed int64""" +# global G_UINT64 +# if G_UINT64 == np.uint64(2 ** 64 - 1): +# G_UINT64 = np.uint64(2 ** 63) +# G_UINT64 = G_UINT64 + np.uint64(1) +# return G_UINT64 + + +# class RootLockTracker: +# def __init__(self): +# self.active_locks = dict() + +# def add_locks(self, root_ids, operation_id, **kwargs): +# if operation_id not in self.active_locks: +# self.active_locks[operation_id] = set(root_ids) +# else: +# self.active_locks[operation_id].update(root_ids) +# return DEFAULT + +# def remove_lock(self, root_id, operation_id, **kwargs): +# if operation_id in self.active_locks: +# self.active_locks[operation_id].discard(root_id) +# return DEFAULT + + +# @pytest.fixture() +# def root_lock_tracker(): +# return RootLockTracker() + + +# def test_successful_lock_acquisition(mocker, root_lock_tracker): +# """Ensure that root locks got released after successful +# root lock acquisition + *successful* graph operation""" +# fake_operation_id = big_uint64() +# fake_locked_root_ids = np.array((big_uint64(), big_uint64())) + +# cg = mocker.MagicMock() +# cg.id_client.create_operation_id = 
mocker.MagicMock(return_value=fake_operation_id) +# cg.client.lock_roots = mocker.MagicMock( +# return_value=(True, fake_locked_root_ids), +# side_effect=root_lock_tracker.add_locks, +# ) +# cg.client.unlock_root = mocker.MagicMock( +# return_value=True, side_effect=root_lock_tracker.remove_lock +# ) + +# with RootLock(cg, fake_locked_root_ids): +# assert fake_operation_id in root_lock_tracker.active_locks +# assert not root_lock_tracker.active_locks[fake_operation_id].difference( +# fake_locked_root_ids +# ) + +# assert not root_lock_tracker.active_locks[fake_operation_id] + + +# def test_failed_lock_acquisition(mocker): +# """Ensure that LockingError is raised when lock acquisition failed""" +# fake_operation_id = big_uint64() +# fake_locked_root_ids = np.array((big_uint64(), big_uint64())) + +# cg = mocker.MagicMock() +# cg.id_client.create_operation_id = mocker.MagicMock(return_value=fake_operation_id) +# cg.client.lock_roots = mocker.MagicMock( +# return_value=(False, fake_locked_root_ids), side_effect=None +# ) + +# with pytest.raises(exceptions.LockingError): +# with RootLock(cg, fake_locked_root_ids): +# pass + + +# def test_failed_graph_operation(mocker, root_lock_tracker): +# """Ensure that root locks got released after successful +# root lock acquisition + *unsuccessful* graph operation""" +# fake_operation_id = big_uint64() +# fake_locked_root_ids = np.array((big_uint64(), big_uint64())) + +# cg = mocker.MagicMock() +# cg.id_client.create_operation_id = mocker.MagicMock(return_value=fake_operation_id) +# cg.client.lock_roots = mocker.MagicMock( +# return_value=(True, fake_locked_root_ids), +# side_effect=root_lock_tracker.add_locks, +# ) +# cg.client.unlock_root = mocker.MagicMock( +# return_value=True, side_effect=root_lock_tracker.remove_lock +# ) + +# with pytest.raises(exceptions.PreconditionError): +# with RootLock(cg, fake_locked_root_ids): +# raise exceptions.PreconditionError("Something went wrong") + +# assert not root_lock_tracker.active_locks[fake_operation_id] diff --git a/pychunkedgraph/tests/test_uncategorized.py b/pychunkedgraph/tests/test_uncategorized.py index 488b87591..93c41158d 100644 --- a/pychunkedgraph/tests/test_uncategorized.py +++ b/pychunkedgraph/tests/test_uncategorized.py @@ -2,11 +2,11 @@ import os import subprocess import sys +from time import sleep from datetime import datetime, timedelta from functools import partial from math import inf from signal import SIGTERM -from time import sleep from unittest import mock from warnings import warn @@ -16,14 +16,27 @@ from google.cloud import bigtable from grpc._channel import _Rendezvous -from helpers import (bigtable_emulator, create_chunk, gen_graph, - gen_graph_simplequerytest, - lock_expired_timedelta_override, to_label) -from pychunkedgraph.backend import chunkedgraph -from pychunkedgraph.backend import chunkedgraph_exceptions as cg_exceptions -from pychunkedgraph.backend.utils import column_keys, serializers -from pychunkedgraph.creator import graph_tests -from pychunkedgraph.meshing import meshgen, meshgen_utils +from .helpers import ( + bigtable_emulator, + create_chunk, + gen_graph, + gen_graph_simplequerytest, + to_label, + sv_data, +) +from ..graph import types +from ..graph import attributes +from ..graph import exceptions +from ..graph import chunkedgraph +from ..graph.edges import Edges +from ..graph.utils import basetypes +from ..graph.misc import get_delta_roots +from ..graph.cutting import run_multicut +from ..graph.lineage import get_root_id_history +from ..graph.lineage import 
get_future_root_ids +from ..graph.utils.serializers import serialize_uint64 +from ..graph.utils.serializers import deserialize_uint64 +from ..ingest.create.abstract_layers import add_layer class TestGraphNodeConversion: @@ -33,38 +46,46 @@ def test_compute_bitmasks(self): @pytest.mark.timeout(30) def test_node_conversion(self, gen_graph): - cgraph = gen_graph(n_layers=10) + cg = gen_graph(n_layers=10) - node_id = cgraph.get_node_id(np.uint64(4), layer=2, x=3, y=1, z=0) - assert cgraph.get_chunk_layer(node_id) == 2 - assert np.all(cgraph.get_chunk_coordinates(node_id) == np.array([3, 1, 0])) + node_id = cg.get_node_id(np.uint64(4), layer=2, x=3, y=1, z=0) + assert cg.get_chunk_layer(node_id) == 2 + assert np.all(cg.get_chunk_coordinates(node_id) == np.array([3, 1, 0])) - chunk_id = cgraph.get_chunk_id(layer=2, x=3, y=1, z=0) - assert cgraph.get_chunk_layer(chunk_id) == 2 - assert np.all(cgraph.get_chunk_coordinates(chunk_id) == np.array([3, 1, 0])) + chunk_id = cg.get_chunk_id(layer=2, x=3, y=1, z=0) + assert cg.get_chunk_layer(chunk_id) == 2 + assert np.all(cg.get_chunk_coordinates(chunk_id) == np.array([3, 1, 0])) - assert cgraph.get_chunk_id(node_id=node_id) == chunk_id - assert cgraph.get_node_id(np.uint64(4), chunk_id=chunk_id) == node_id + assert cg.get_chunk_id(node_id=node_id) == chunk_id + assert cg.get_node_id(np.uint64(4), chunk_id=chunk_id) == node_id @pytest.mark.timeout(30) def test_node_id_adjacency(self, gen_graph): - cgraph = gen_graph(n_layers=10) + cg = gen_graph(n_layers=10) - assert cgraph.get_node_id(np.uint64(0), layer=2, x=3, y=1, z=0) + np.uint64(1) == \ - cgraph.get_node_id(np.uint64(1), layer=2, x=3, y=1, z=0) + assert cg.get_node_id(np.uint64(0), layer=2, x=3, y=1, z=0) + np.uint64( + 1 + ) == cg.get_node_id(np.uint64(1), layer=2, x=3, y=1, z=0) - assert cgraph.get_node_id(np.uint64(2**53 - 2), layer=10, x=0, y=0, z=0) + np.uint64(1) == \ - cgraph.get_node_id(np.uint64(2**53 - 1), layer=10, x=0, y=0, z=0) + assert cg.get_node_id( + np.uint64(2 ** 53 - 2), layer=10, x=0, y=0, z=0 + ) + np.uint64(1) == cg.get_node_id( + np.uint64(2 ** 53 - 1), layer=10, x=0, y=0, z=0 + ) @pytest.mark.timeout(30) def test_serialize_node_id(self, gen_graph): - cgraph = gen_graph(n_layers=10) + cg = gen_graph(n_layers=10) - assert serializers.serialize_uint64(cgraph.get_node_id(np.uint64(0), layer=2, x=3, y=1, z=0)) < \ - serializers.serialize_uint64(cgraph.get_node_id(np.uint64(1), layer=2, x=3, y=1, z=0)) + assert serialize_uint64( + cg.get_node_id(np.uint64(0), layer=2, x=3, y=1, z=0) + ) < serialize_uint64(cg.get_node_id(np.uint64(1), layer=2, x=3, y=1, z=0)) - assert serializers.serialize_uint64(cgraph.get_node_id(np.uint64(2 ** 53 - 2), layer=10, x=0, y=0, z=0)) < \ - serializers.serialize_uint64(cgraph.get_node_id(np.uint64(2 ** 53 - 1), layer=10, x=0, y=0, z=0)) + assert serialize_uint64( + cg.get_node_id(np.uint64(2 ** 53 - 2), layer=10, x=0, y=0, z=0) + ) < serialize_uint64( + cg.get_node_id(np.uint64(2 ** 53 - 1), layer=10, x=0, y=0, z=0) + ) @pytest.mark.timeout(30) def test_deserialize_node_id(self): @@ -77,8 +98,7 @@ def test_serialization_roundtrip(self): @pytest.mark.timeout(30) def test_serialize_valid_label_id(self): label = np.uint64(0x01FF031234556789) - assert serializers.deserialize_uint64( - serializers.serialize_uint64(label)) == label + assert deserialize_uint64(serialize_uint64(label)) == label class TestGraphBuild: @@ -93,47 +113,33 @@ def test_build_single_node(self, gen_graph): └─────┘ """ - cgraph = gen_graph(n_layers=2) - + cg = gen_graph(n_layers=2) # 
Add Chunk A - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)]) + create_chunk(cg, vertices=[to_label(cg, 1, 0, 0, 0, 0)]) - res = cgraph.table.read_rows() + res = cg.client._table.read_rows() res.consume_all() - # Check for the RG-to-CG mapping: - # assert chunkedgraph.serialize_uint64(1) in res.rows - # row = res.rows[chunkedgraph.serialize_uint64(1)].cells[cgraph.family_id] - # assert np.frombuffer(row[b'cg_id'][0].value, np.uint64)[0] == to_label(cgraph, 1, 0, 0, 0, 0) - - # Check for the Level 1 CG supervoxel: - # to_label(cgraph, 1, 0, 0, 0, 0) - assert serializers.serialize_uint64(to_label(cgraph, 1, 0, 0, 0, 0)) in res.rows - atomic_node_info = cgraph.get_atomic_node_info(to_label(cgraph, 1, 0, 0, 0, 0)) - atomic_affinities = atomic_node_info[column_keys.Connectivity.Affinity] - atomic_partners = atomic_node_info[column_keys.Connectivity.Partner] - parents = atomic_node_info[column_keys.Hierarchy.Parent] - - assert len(atomic_partners) == 0 - assert len(atomic_affinities) == 0 - assert len(parents) == 1 and parents[0] == to_label(cgraph, 2, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) # Check for the one Level 2 node that should have been created. - # to_label(cgraph, 2, 0, 0, 0, 1) - assert serializers.serialize_uint64(to_label(cgraph, 2, 0, 0, 0, 1)) in res.rows - row = res.rows[serializers.serialize_uint64(to_label(cgraph, 2, 0, 0, 0, 1))].cells[cgraph.family_id] - atomic_cross_edge_dict = cgraph.get_atomic_cross_edge_dict(to_label(cgraph, 2, 0, 0, 0, 1)) - column = column_keys.Hierarchy.Child - children = column.deserialize(row[column.key][0].value) + assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows + atomic_cross_edge_d = cg.get_atomic_cross_edges( + np.array([to_label(cg, 2, 0, 0, 0, 1)], dtype=basetypes.NODE_ID) + ) + attr = attributes.Hierarchy.Child + row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] + children = attr.deserialize(row[attr.key][0].value) - for aces in atomic_cross_edge_dict.values(): + for aces in atomic_cross_edge_d.values(): assert len(aces) == 0 - assert len(children) == 1 and children[0] == to_label(cgraph, 1, 0, 0, 0, 0) - + assert len(children) == 1 and children[0] == to_label(cg, 1, 0, 0, 0, 0) # Make sure there are not any more entries in the table - assert len(res.rows) == 1 + 1 + 1 + 1 + # include counters, meta and version rows + assert len(res.rows) == 1 + 1 + 1 + 1 + 1 @pytest.mark.timeout(30) def test_build_single_edge(self, gen_graph): @@ -146,63 +152,47 @@ def test_build_single_edge(self, gen_graph): └─────┘ """ - cgraph = gen_graph(n_layers=2) + cg = gen_graph(n_layers=2) # Add Chunk A - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5)]) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5)], + ) - res = cgraph.table.read_rows() + res = cg.client._table.read_rows() res.consume_all() - # Check for the two RG-to-CG mappings: - # assert chunkedgraph.serialize_uint64(1) in res.rows - # row = res.rows[chunkedgraph.serialize_uint64(1)].cells[cgraph.family_id] - # assert np.frombuffer(row[b'cg_id'][0].value, np.uint64)[0] == to_label(cgraph, 1, 0, 0, 0, 0) - - # assert chunkedgraph.serialize_uint64(2) in res.rows - # row 
= res.rows[chunkedgraph.serialize_uint64(2)].cells[cgraph.family_id] - # assert np.frombuffer(row[b'cg_id'][0].value, np.uint64)[0] == to_label(cgraph, 1, 0, 0, 0, 1) - - # Check for the two original Level 1 CG supervoxels - # to_label(cgraph, 1, 0, 0, 0, 0) - assert serializers.serialize_uint64(to_label(cgraph, 1, 0, 0, 0, 0)) in res.rows - atomic_node_info = cgraph.get_atomic_node_info(to_label(cgraph, 1, 0, 0, 0, 0)) - atomic_affinities = atomic_node_info[column_keys.Connectivity.Affinity] - atomic_partners = atomic_node_info[column_keys.Connectivity.Partner] - parents = atomic_node_info[column_keys.Hierarchy.Parent] - - assert len(atomic_partners) == 1 and atomic_partners[0] == to_label(cgraph, 1, 0, 0, 0, 1) - assert len(atomic_affinities) == 1 and atomic_affinities[0] == 0.5 - assert len(parents) == 1 and parents[0] == to_label(cgraph, 2, 0, 0, 0, 1) - - # to_label(cgraph, 1, 0, 0, 0, 1) - assert serializers.serialize_uint64(to_label(cgraph, 1, 0, 0, 0, 1)) in res.rows - atomic_node_info = cgraph.get_atomic_node_info(to_label(cgraph, 1, 0, 0, 0, 1)) - atomic_affinities = atomic_node_info[column_keys.Connectivity.Affinity] - atomic_partners = atomic_node_info[column_keys.Connectivity.Partner] - parents = atomic_node_info[column_keys.Hierarchy.Parent] - - assert len(atomic_partners) == 1 and atomic_partners[0] == to_label(cgraph, 1, 0, 0, 0, 0) - assert len(atomic_affinities) == 1 and atomic_affinities[0] == 0.5 - assert len(parents) == 1 and parents[0] == to_label(cgraph, 2, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) + + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 1)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 1)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) # Check for the one Level 2 node that should have been created. 
- assert serializers.serialize_uint64(to_label(cgraph, 2, 0, 0, 0, 1)) in res.rows - row = res.rows[serializers.serialize_uint64(to_label(cgraph, 2, 0, 0, 0, 1))].cells[cgraph.family_id] + assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows - atomic_cross_edge_dict = cgraph.get_atomic_cross_edge_dict(to_label(cgraph, 2, 0, 0, 0, 1)) - column = column_keys.Hierarchy.Child - children = column.deserialize(row[column.key][0].value) + atomic_cross_edge_d = cg.get_atomic_cross_edges( + np.array([to_label(cg, 2, 0, 0, 0, 1)], dtype=basetypes.NODE_ID) + ) + attr = attributes.Hierarchy.Child + row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] + children = attr.deserialize(row[attr.key][0].value) - for aces in atomic_cross_edge_dict.values(): + for aces in atomic_cross_edge_d.values(): assert len(aces) == 0 - - assert len(children) == 2 and to_label(cgraph, 1, 0, 0, 0, 0) in children and to_label(cgraph, 1, 0, 0, 0, 1) in children + assert ( + len(children) == 2 + and to_label(cg, 1, 0, 0, 0, 0) in children + and to_label(cg, 1, 0, 0, 0, 1) in children + ) # Make sure there are not any more entries in the table - assert len(res.rows) == 2 + 1 + 1 + 1 + # include counters, meta and version rows + assert len(res.rows) == 2 + 1 + 1 + 1 + 1 @pytest.mark.timeout(30) def test_build_single_across_edge(self, gen_graph): @@ -215,104 +205,96 @@ def test_build_single_across_edge(self, gen_graph): └─────┴─────┘ """ - cgraph = gen_graph(n_layers=3) + atomic_chunk_bounds = np.array([2, 1, 1]) + cg = gen_graph(n_layers=3, atomic_chunk_bounds=atomic_chunk_bounds) # Chunk A - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), inf)]) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf)], + ) # Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), inf)]) - - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), n_threads=1) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], + ) - res = cgraph.table.read_rows() + add_layer(cg, 3, [0, 0, 0], n_threads=1) + res = cg.client._table.read_rows() res.consume_all() - # Check for the two RG-to-CG mappings: - # assert chunkedgraph.serialize_uint64(1) in res.rows - # row = res.rows[chunkedgraph.serialize_uint64(1)].cells[cgraph.family_id] - # assert np.frombuffer(row[b'cg_id'][0].value, np.uint64)[0] == to_label(cgraph, 1, 0, 0, 0, 0) - - # assert chunkedgraph.serialize_uint64(2) in res.rows - # row = res.rows[chunkedgraph.serialize_uint64(2)].cells[cgraph.family_id] - # assert np.frombuffer(row[b'cg_id'][0].value, np.uint64)[0] == to_label(cgraph, 1, 1, 0, 0, 0) - - # Check for the two original Level 1 CG supervoxels - # to_label(cgraph, 1, 0, 0, 0, 0) - assert serializers.serialize_uint64(to_label(cgraph, 1, 0, 0, 0, 0)) in res.rows - - atomic_node_info = cgraph.get_atomic_node_info(to_label(cgraph, 1, 0, 0, 0, 0)) - atomic_affinities = atomic_node_info[column_keys.Connectivity.Affinity] - atomic_partners = atomic_node_info[column_keys.Connectivity.Partner] - - cgraph.logger.debug(atomic_node_info.keys()) - - parents = atomic_node_info[column_keys.Hierarchy.Parent] + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) + 
assert parent == to_label(cg, 2, 0, 0, 0, 1) - assert len(atomic_partners) == 1 and atomic_partners[0] == to_label(cgraph, 1, 1, 0, 0, 0) - assert len(atomic_affinities) == 1 and atomic_affinities[0] == inf - assert len(parents) == 1 and parents[0] == to_label(cgraph, 2, 0, 0, 0, 1) - - # to_label(cgraph, 1, 1, 0, 0, 0) - assert serializers.serialize_uint64(to_label(cgraph, 1, 1, 0, 0, 0)) in res.rows - - atomic_node_info = cgraph.get_atomic_node_info(to_label(cgraph, 1, 1, 0, 0, 0)) - atomic_affinities = atomic_node_info[column_keys.Connectivity.Affinity] - atomic_partners = atomic_node_info[column_keys.Connectivity.Partner] - parents = atomic_node_info[column_keys.Hierarchy.Parent] - - assert len(atomic_partners) == 1 and atomic_partners[0] == to_label(cgraph, 1, 0, 0, 0, 0) - assert len(atomic_affinities) == 1 and atomic_affinities[0] == inf - assert len(parents) == 1 and parents[0] == to_label(cgraph, 2, 1, 0, 0, 1) + assert serialize_uint64(to_label(cg, 1, 1, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) + assert parent == to_label(cg, 2, 1, 0, 0, 1) # Check for the two Level 2 nodes that should have been created. Since Level 2 has the same # dimensions as Level 1, we also expect them to be in different chunks - # to_label(cgraph, 2, 0, 0, 0, 1) - assert serializers.serialize_uint64(to_label(cgraph, 2, 0, 0, 0, 1)) in res.rows - row = res.rows[serializers.serialize_uint64(to_label(cgraph, 2, 0, 0, 0, 1))].cells[cgraph.family_id] - - atomic_cross_edge_dict = cgraph.get_atomic_cross_edge_dict(to_label(cgraph, 2, 0, 0, 0, 1)) - column = column_keys.Hierarchy.Child - children = column.deserialize(row[column.key][0].value) + # to_label(cg, 2, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows + atomic_cross_edge_d = cg.get_atomic_cross_edges( + np.array([to_label(cg, 2, 0, 0, 0, 1)], dtype=basetypes.NODE_ID) + ) + atomic_cross_edge_d = atomic_cross_edge_d[ + np.uint64(to_label(cg, 2, 0, 0, 0, 1)) + ] - test_ace = np.array([to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0)], dtype=np.uint64) - assert len(atomic_cross_edge_dict[2]) == 1 - assert test_ace in atomic_cross_edge_dict[2] - assert len(children) == 1 and to_label(cgraph, 1, 0, 0, 0, 0) in children + attr = attributes.Hierarchy.Child + row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] + children = attr.deserialize(row[attr.key][0].value) - # to_label(cgraph, 2, 1, 0, 0, 1) - assert serializers.serialize_uint64(to_label(cgraph, 2, 1, 0, 0, 1)) in res.rows - row = res.rows[serializers.serialize_uint64(to_label(cgraph, 2, 1, 0, 0, 1))].cells[cgraph.family_id] + test_ace = np.array( + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)], + dtype=np.uint64, + ) + assert len(atomic_cross_edge_d[2]) == 1 + assert test_ace in atomic_cross_edge_d[2] + assert len(children) == 1 and to_label(cg, 1, 0, 0, 0, 0) in children + + # to_label(cg, 2, 1, 0, 0, 1) + assert serialize_uint64(to_label(cg, 2, 1, 0, 0, 1)) in res.rows + atomic_cross_edge_d = cg.get_atomic_cross_edges( + np.array([to_label(cg, 2, 1, 0, 0, 1)], dtype=basetypes.NODE_ID) + ) + atomic_cross_edge_d = atomic_cross_edge_d[ + np.uint64(to_label(cg, 2, 1, 0, 0, 1)) + ] - atomic_cross_edge_dict = cgraph.get_atomic_cross_edge_dict(to_label(cgraph, 2, 1, 0, 0, 1)) - column = column_keys.Hierarchy.Child - children = column.deserialize(row[column.key][0].value) + attr = attributes.Hierarchy.Child + row = res.rows[serialize_uint64(to_label(cg, 2, 1, 0, 0, 1))].cells["0"] + children = 
attr.deserialize(row[attr.key][0].value) - test_ace = np.array([to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0)], dtype=np.uint64) - assert len(atomic_cross_edge_dict[2]) == 1 - assert test_ace in atomic_cross_edge_dict[2] - assert len(children) == 1 and to_label(cgraph, 1, 1, 0, 0, 0) in children + test_ace = np.array( + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], + dtype=np.uint64, + ) + assert len(atomic_cross_edge_d[2]) == 1 + assert test_ace in atomic_cross_edge_d[2] + assert len(children) == 1 and to_label(cg, 1, 1, 0, 0, 0) in children # Check for the one Level 3 node that should have been created. This one combines the two # connected components of Level 2 - # to_label(cgraph, 3, 0, 0, 0, 1) - assert serializers.serialize_uint64(to_label(cgraph, 3, 0, 0, 0, 1)) in res.rows - row = res.rows[serializers.serialize_uint64(to_label(cgraph, 3, 0, 0, 0, 1))].cells[cgraph.family_id] - atomic_cross_edge_dict = cgraph.get_atomic_cross_edge_dict(to_label(cgraph, 3, 0, 0, 0, 1)) - column = column_keys.Hierarchy.Child - children = column.deserialize(row[column.key][0].value) - - - for aces in atomic_cross_edge_dict.values(): - assert len(aces) == 0 - assert len(children) == 2 and to_label(cgraph, 2, 0, 0, 0, 1) in children and to_label(cgraph, 2, 1, 0, 0, 1) in children + # to_label(cg, 3, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 3, 0, 0, 0, 1)) in res.rows + + attr = attributes.Hierarchy.Child + row = res.rows[serialize_uint64(to_label(cg, 3, 0, 0, 0, 1))].cells["0"] + children = attr.deserialize(row[attr.key][0].value) + assert ( + len(children) == 2 + and to_label(cg, 2, 0, 0, 0, 1) in children + and to_label(cg, 2, 1, 0, 0, 1) in children + ) # Make sure there are not any more entries in the table - assert len(res.rows) == 2 + 2 + 1 + 3 + 1 + # include counters, meta and version rows + assert len(res.rows) == 2 + 2 + 1 + 3 + 1 + 1 @pytest.mark.timeout(30) def test_build_single_edge_and_single_across_edge(self, gen_graph): @@ -326,118 +308,103 @@ def test_build_single_edge_and_single_across_edge(self, gen_graph): └─────┴─────┘ """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Chunk A - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), inf)]) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), + ], + ) # Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), inf)]) - - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), n_threads=1) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], + ) - res = cgraph.table.read_rows() + add_layer(cg, 3, np.array([0, 0, 0]), n_threads=1) + res = cg.client._table.read_rows() res.consume_all() - # Check for the three RG-to-CG mappings: - # assert chunkedgraph.serialize_uint64(1) in res.rows - # row = res.rows[chunkedgraph.serialize_uint64(1)].cells[cgraph.family_id] - # assert np.frombuffer(row[b'cg_id'][0].value, np.uint64)[0] == to_label(cgraph, 1, 0, 0, 0, 0) - - # assert chunkedgraph.serialize_uint64(2) in res.rows - # row = 
res.rows[chunkedgraph.serialize_uint64(2)].cells[cgraph.family_id] - # assert np.frombuffer(row[b'cg_id'][0].value, np.uint64)[0] == to_label(cgraph, 1, 0, 0, 0, 1) - - # assert chunkedgraph.serialize_uint64(3) in res.rows - # row = res.rows[chunkedgraph.serialize_uint64(3)].cells[cgraph.family_id] - # assert np.frombuffer(row[b'cg_id'][0].value, np.uint64)[0] == to_label(cgraph, 1, 1, 0, 0, 0) - - # Check for the three original Level 1 CG supervoxels - # to_label(cgraph, 1, 0, 0, 0, 0) - assert serializers.serialize_uint64(to_label(cgraph, 1, 0, 0, 0, 0)) in res.rows - atomic_node_info = cgraph.get_atomic_node_info(to_label(cgraph, 1, 0, 0, 0, 0)) - atomic_affinities = atomic_node_info[column_keys.Connectivity.Affinity] - atomic_partners = atomic_node_info[column_keys.Connectivity.Partner] - parents = atomic_node_info[column_keys.Hierarchy.Parent] - - assert len(atomic_partners) == 2 and to_label(cgraph, 1, 0, 0, 0, 1) in atomic_partners and to_label(cgraph, 1, 1, 0, 0, 0) in atomic_partners - assert len(atomic_affinities) == 2 - if atomic_partners[0] == to_label(cgraph, 1, 0, 0, 0, 1): - assert atomic_affinities[0] == 0.5 and atomic_affinities[1] == inf - else: - assert atomic_affinities[0] == inf and atomic_affinities[1] == 0.5 - assert len(parents) == 1 and parents[0] == to_label(cgraph, 2, 0, 0, 0, 1) - - # to_label(cgraph, 1, 0, 0, 0, 1) - assert serializers.serialize_uint64(to_label(cgraph, 1, 0, 0, 0, 1)) in res.rows - atomic_node_info = cgraph.get_atomic_node_info(to_label(cgraph, 1, 0, 0, 0, 1)) - atomic_affinities = atomic_node_info[column_keys.Connectivity.Affinity] - atomic_partners = atomic_node_info[column_keys.Connectivity.Partner] - parents = atomic_node_info[column_keys.Hierarchy.Parent] - - assert len(atomic_partners) == 1 and atomic_partners[0] == to_label(cgraph, 1, 0, 0, 0, 0) - assert len(atomic_affinities) == 1 and atomic_affinities[0] == 0.5 - assert len(parents) == 1 and parents[0] == to_label(cgraph, 2, 0, 0, 0, 1) - - # to_label(cgraph, 1, 1, 0, 0, 0) - assert serializers.serialize_uint64(to_label(cgraph, 1, 1, 0, 0, 0)) in res.rows - atomic_node_info = cgraph.get_atomic_node_info(to_label(cgraph, 1, 1, 0, 0, 0)) - atomic_affinities = atomic_node_info[column_keys.Connectivity.Affinity] - atomic_partners = atomic_node_info[column_keys.Connectivity.Partner] - parents = atomic_node_info[column_keys.Hierarchy.Parent] - - assert len(atomic_partners) == 1 and atomic_partners[0] == to_label(cgraph, 1, 0, 0, 0, 0) - assert len(atomic_affinities) == 1 and atomic_affinities[0] == inf - assert len(parents) == 1 and parents[0] == to_label(cgraph, 2, 1, 0, 0, 1) + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) + + # to_label(cg, 1, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 1)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 1)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) + + # to_label(cg, 1, 1, 0, 0, 0) + assert serialize_uint64(to_label(cg, 1, 1, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) + assert parent == to_label(cg, 2, 1, 0, 0, 1) # Check for the two Level 2 nodes that should have been created. 
Since Level 2 has the same # dimensions as Level 1, we also expect them to be in different chunks - # to_label(cgraph, 2, 0, 0, 0, 1) - assert serializers.serialize_uint64(to_label(cgraph, 2, 0, 0, 0, 1)) in res.rows - row = res.rows[serializers.serialize_uint64(to_label(cgraph, 2, 0, 0, 0, 1))].cells[cgraph.family_id] - atomic_cross_edge_dict = cgraph.get_atomic_cross_edge_dict(to_label(cgraph, 2, 0, 0, 0, 1)) - column = column_keys.Hierarchy.Child + # to_label(cg, 2, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows + row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] + atomic_cross_edge_d = cg.get_atomic_cross_edges([to_label(cg, 2, 0, 0, 0, 1)]) + atomic_cross_edge_d = atomic_cross_edge_d[ + np.uint64(to_label(cg, 2, 0, 0, 0, 1)) + ] + column = attributes.Hierarchy.Child children = column.deserialize(row[column.key][0].value) - test_ace = np.array([to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0)], dtype=np.uint64) - assert len(atomic_cross_edge_dict[2]) == 1 - assert test_ace in atomic_cross_edge_dict[2] - assert len(children) == 2 and to_label(cgraph, 1, 0, 0, 0, 0) in children and to_label(cgraph, 1, 0, 0, 0, 1) in children + test_ace = np.array( + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)], + dtype=np.uint64, + ) + assert len(atomic_cross_edge_d[2]) == 1 + assert test_ace in atomic_cross_edge_d[2] + assert ( + len(children) == 2 + and to_label(cg, 1, 0, 0, 0, 0) in children + and to_label(cg, 1, 0, 0, 0, 1) in children + ) - # to_label(cgraph, 2, 1, 0, 0, 1) - assert serializers.serialize_uint64(to_label(cgraph, 2, 1, 0, 0, 1)) in res.rows - row = res.rows[serializers.serialize_uint64(to_label(cgraph, 2, 1, 0, 0, 1))].cells[cgraph.family_id] - atomic_cross_edge_dict = cgraph.get_atomic_cross_edge_dict(to_label(cgraph, 2, 1, 0, 0, 1)) - column = column_keys.Hierarchy.Child + # to_label(cg, 2, 1, 0, 0, 1) + assert serialize_uint64(to_label(cg, 2, 1, 0, 0, 1)) in res.rows + row = res.rows[serialize_uint64(to_label(cg, 2, 1, 0, 0, 1))].cells["0"] + atomic_cross_edge_d = cg.get_atomic_cross_edges([to_label(cg, 2, 1, 0, 0, 1)]) + atomic_cross_edge_d = atomic_cross_edge_d[ + np.uint64(to_label(cg, 2, 1, 0, 0, 1)) + ] children = column.deserialize(row[column.key][0].value) - test_ace = np.array([to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0)], dtype=np.uint64) - assert len(atomic_cross_edge_dict[2]) == 1 - assert test_ace in atomic_cross_edge_dict[2] - assert len(children) == 1 and to_label(cgraph, 1, 1, 0, 0, 0) in children + test_ace = np.array( + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], + dtype=np.uint64, + ) + assert len(atomic_cross_edge_d[2]) == 1 + assert test_ace in atomic_cross_edge_d[2] + assert len(children) == 1 and to_label(cg, 1, 1, 0, 0, 0) in children # Check for the one Level 3 node that should have been created. 
This one combines the two # connected components of Level 2 - # to_label(cgraph, 3, 0, 0, 0, 1) - assert serializers.serialize_uint64(to_label(cgraph, 3, 0, 0, 0, 1)) in res.rows - row = res.rows[serializers.serialize_uint64(to_label(cgraph, 3, 0, 0, 0, 1))].cells[cgraph.family_id] - atomic_cross_edge_dict = cgraph.get_atomic_cross_edge_dict(to_label(cgraph, 3, 0, 0, 0, 1)) - column = column_keys.Hierarchy.Child + # to_label(cg, 3, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 3, 0, 0, 0, 1)) in res.rows + row = res.rows[serialize_uint64(to_label(cg, 3, 0, 0, 0, 1))].cells["0"] + column = attributes.Hierarchy.Child children = column.deserialize(row[column.key][0].value) - for ace in atomic_cross_edge_dict.values(): - assert len(ace) == 0 - assert len(children) == 2 and to_label(cgraph, 2, 0, 0, 0, 1) in children and to_label(cgraph, 2, 1, 0, 0, 1) in children + assert ( + len(children) == 2 + and to_label(cg, 2, 0, 0, 0, 1) in children + and to_label(cg, 2, 1, 0, 0, 1) in children + ) # Make sure there are not any more entries in the table - assert len(res.rows) == 3 + 2 + 1 + 3 + 1 + # include counters, meta and version rows + assert len(res.rows) == 3 + 2 + 1 + 3 + 1 + 1 - @pytest.mark.timeout(30) + @pytest.mark.timeout(120) def test_build_big_graph(self, gen_graph): """ Create graph with RG nodes 1 and 2 in opposite corners of the largest possible dataset @@ -448,47 +415,27 @@ def test_build_big_graph(self, gen_graph): └─────┘ └─────┘ """ - cgraph = gen_graph(n_layers=10) + atomic_chunk_bounds = np.array([8, 8, 8]) + cg = gen_graph(n_layers=5, atomic_chunk_bounds=atomic_chunk_bounds) # Preparation: Build Chunk A - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[]) + create_chunk(cg, vertices=[to_label(cg, 1, 0, 0, 0, 0)], edges=[]) # Preparation: Build Chunk Z - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 255, 255, 255, 0)], - edges=[]) - - cgraph.add_layer(3, np.array([[0x00, 0x00, 0x00]]), n_threads=1) - cgraph.add_layer(3, np.array([[0xFF, 0xFF, 0xFF]]), n_threads=1) - cgraph.add_layer(4, np.array([[0x00, 0x00, 0x00]]), n_threads=1) - cgraph.add_layer(4, np.array([[0x7F, 0x7F, 0x7F]]), n_threads=1) - cgraph.add_layer(5, np.array([[0x00, 0x00, 0x00]]), n_threads=1) - cgraph.add_layer(5, np.array([[0x3F, 0x3F, 0x3F]]), n_threads=1) - cgraph.add_layer(6, np.array([[0x00, 0x00, 0x00]]), n_threads=1) - cgraph.add_layer(6, np.array([[0x1F, 0x1F, 0x1F]]), n_threads=1) - cgraph.add_layer(7, np.array([[0x00, 0x00, 0x00]]), n_threads=1) - cgraph.add_layer(7, np.array([[0x0F, 0x0F, 0x0F]]), n_threads=1) - cgraph.add_layer(8, np.array([[0x00, 0x00, 0x00]]), n_threads=1) - cgraph.add_layer(8, np.array([[0x07, 0x07, 0x07]]), n_threads=1) - cgraph.add_layer(9, np.array([[0x00, 0x00, 0x00]]), n_threads=1) - cgraph.add_layer(9, np.array([[0x03, 0x03, 0x03]]), n_threads=1) - cgraph.add_layer(10, np.array([[0x00, 0x00, 0x00], [0x01, 0x01, 0x01]]), n_threads=1) - - res = cgraph.table.read_rows() - res.consume_all() + create_chunk(cg, vertices=[to_label(cg, 1, 7, 7, 7, 0)], edges=[]) - # cgraph.logger.debug(len(res.rows)) - # for row_key in res.rows.keys(): - # cgraph.logger.debug(row_key) - # cgraph.logger.debug(cgraph.get_chunk_layer(chunkedgraph.deserialize_uint64(row_key))) - # cgraph.logger.debug(cgraph.get_chunk_coordinates(chunkedgraph.deserialize_uint64(row_key))) + add_layer(cg, 3, [0, 0, 0], n_threads=1) + add_layer(cg, 3, [3, 3, 3], n_threads=1) + add_layer(cg, 4, [0, 0, 0], n_threads=1) + add_layer(cg, 5, [0, 0, 0], n_threads=1) + + res = 
cg.client._table.read_rows() + res.consume_all() - assert serializers.serialize_uint64(to_label(cgraph, 1, 0, 0, 0, 0)) in res.rows - assert serializers.serialize_uint64(to_label(cgraph, 1, 255, 255, 255, 0)) in res.rows - assert serializers.serialize_uint64(to_label(cgraph, 10, 0, 0, 0, 1)) in res.rows - assert serializers.serialize_uint64(to_label(cgraph, 10, 0, 0, 0, 2)) in res.rows + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows + assert serialize_uint64(to_label(cg, 1, 7, 7, 7, 0)) in res.rows + assert serialize_uint64(to_label(cg, 5, 0, 0, 0, 1)) in res.rows + assert serialize_uint64(to_label(cg, 5, 0, 0, 0, 2)) in res.rows @pytest.mark.timeout(30) def test_double_chunk_creation(self, gen_graph): @@ -501,41 +448,62 @@ def test_double_chunk_creation(self, gen_graph): └─────┴─────┘ """ - cgraph = gen_graph(n_layers=4) + atomic_chunk_bounds = np.array([4, 4, 4]) + cg = gen_graph(n_layers=4, atomic_chunk_bounds=atomic_chunk_bounds) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 1), - to_label(cgraph, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp) - - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), - time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), - time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(4, np.array([[0, 0, 0]]), - time_stamp=fake_timestamp, n_threads=1) - - assert len(cgraph.range_read_chunk(layer=2, x=0, y=0, z=0)) == 2 - assert len(cgraph.range_read_chunk(layer=2, x=1, y=0, z=0)) == 1 - assert len(cgraph.range_read_chunk(layer=3, x=0, y=0, z=0)) == 0 - assert len(cgraph.range_read_chunk(layer=4, x=0, y=0, z=0)) == 6 - - assert cgraph.get_chunk_layer(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1))) == 4 - assert cgraph.get_chunk_layer(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 2))) == 4 - assert cgraph.get_chunk_layer(cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 1))) == 4 - - root_seg_ids = [cgraph.get_segment_id(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1))), - cgraph.get_segment_id(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 2))), - cgraph.get_segment_id(cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 1)))] + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 4, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + assert len(cg.range_read_chunk(cg.get_chunk_id(layer=2, x=0, y=0, z=0))) == 2 + assert len(cg.range_read_chunk(cg.get_chunk_id(layer=2, x=1, y=0, z=0))) == 1 + assert len(cg.range_read_chunk(cg.get_chunk_id(layer=3, x=0, y=0, z=0))) == 0 + assert len(cg.range_read_chunk(cg.get_chunk_id(layer=4, x=0, y=0, z=0))) == 6 + + assert cg.get_chunk_layer(cg.get_root(to_label(cg, 1, 0, 0, 0, 1))) == 4 + assert cg.get_chunk_layer(cg.get_root(to_label(cg, 1, 0, 0, 0, 2))) == 4 + assert cg.get_chunk_layer(cg.get_root(to_label(cg, 1, 1, 0, 0, 1))) == 4 + + root_seg_ids = [ + cg.get_segment_id(cg.get_root(to_label(cg, 1, 0, 0, 0, 1))), + 
cg.get_segment_id(cg.get_root(to_label(cg, 1, 0, 0, 0, 2))), + cg.get_segment_id(cg.get_root(to_label(cg, 1, 1, 0, 0, 1))), + ] assert 4 in root_seg_ids assert 5 in root_seg_ids @@ -550,41 +518,64 @@ class TestGraphSimpleQueries: │ │ │ │ 3: 1 1 0 0 1 ─┘ │ └─────┴─────┴─────┘ 4: 1 2 0 0 0 ─── 2 2 0 0 1 ─── 3 1 0 0 1 ─┘ """ + @pytest.mark.timeout(30) def test_get_parent_and_children(self, gen_graph_simplequerytest): - cgraph = gen_graph_simplequerytest + cg = gen_graph_simplequerytest - children10000 = cgraph.get_children(to_label(cgraph, 1, 0, 0, 0, 0)) - children11000 = cgraph.get_children(to_label(cgraph, 1, 1, 0, 0, 0)) - children11001 = cgraph.get_children(to_label(cgraph, 1, 1, 0, 0, 1)) - children12000 = cgraph.get_children(to_label(cgraph, 1, 2, 0, 0, 0)) + children10000 = cg.get_children(to_label(cg, 1, 0, 0, 0, 0)) + children11000 = cg.get_children(to_label(cg, 1, 1, 0, 0, 0)) + children11001 = cg.get_children(to_label(cg, 1, 1, 0, 0, 1)) + children12000 = cg.get_children(to_label(cg, 1, 2, 0, 0, 0)) - parent10000 = cgraph.get_parent(to_label(cgraph, 1, 0, 0, 0, 0), get_only_relevant_parent=True, time_stamp=None) - parent11000 = cgraph.get_parent(to_label(cgraph, 1, 1, 0, 0, 0), get_only_relevant_parent=True, time_stamp=None) - parent11001 = cgraph.get_parent(to_label(cgraph, 1, 1, 0, 0, 1), get_only_relevant_parent=True, time_stamp=None) - parent12000 = cgraph.get_parent(to_label(cgraph, 1, 2, 0, 0, 0), get_only_relevant_parent=True, time_stamp=None) + parent10000 = cg.get_parent( + to_label(cg, 1, 0, 0, 0, 0), + ) + parent11000 = cg.get_parent( + to_label(cg, 1, 1, 0, 0, 0), + ) + parent11001 = cg.get_parent( + to_label(cg, 1, 1, 0, 0, 1), + ) + parent12000 = cg.get_parent( + to_label(cg, 1, 2, 0, 0, 0), + ) - children20001 = cgraph.get_children(to_label(cgraph, 2, 0, 0, 0, 1)) - children21001 = cgraph.get_children(to_label(cgraph, 2, 1, 0, 0, 1)) - children22001 = cgraph.get_children(to_label(cgraph, 2, 2, 0, 0, 1)) + children20001 = cg.get_children(to_label(cg, 2, 0, 0, 0, 1)) + children21001 = cg.get_children(to_label(cg, 2, 1, 0, 0, 1)) + children22001 = cg.get_children(to_label(cg, 2, 2, 0, 0, 1)) - parent20001 = cgraph.get_parent(to_label(cgraph, 2, 0, 0, 0, 1), get_only_relevant_parent=True, time_stamp=None) - parent21001 = cgraph.get_parent(to_label(cgraph, 2, 1, 0, 0, 1), get_only_relevant_parent=True, time_stamp=None) - parent22001 = cgraph.get_parent(to_label(cgraph, 2, 2, 0, 0, 1), get_only_relevant_parent=True, time_stamp=None) + parent20001 = cg.get_parent( + to_label(cg, 2, 0, 0, 0, 1), + ) + parent21001 = cg.get_parent( + to_label(cg, 2, 1, 0, 0, 1), + ) + parent22001 = cg.get_parent( + to_label(cg, 2, 2, 0, 0, 1), + ) - children30001 = cgraph.get_children(to_label(cgraph, 3, 0, 0, 0, 1)) - # children30002 = cgraph.get_children(to_label(cgraph, 3, 0, 0, 0, 2)) - children31001 = cgraph.get_children(to_label(cgraph, 3, 1, 0, 0, 1)) + children30001 = cg.get_children(to_label(cg, 3, 0, 0, 0, 1)) + # children30002 = cg.get_children(to_label(cg, 3, 0, 0, 0, 2)) + children31001 = cg.get_children(to_label(cg, 3, 1, 0, 0, 1)) - parent30001 = cgraph.get_parent(to_label(cgraph, 3, 0, 0, 0, 1), get_only_relevant_parent=True, time_stamp=None) - # parent30002 = cgraph.get_parent(to_label(cgraph, 3, 0, 0, 0, 2), get_only_relevant_parent=True, time_stamp=None) - parent31001 = cgraph.get_parent(to_label(cgraph, 3, 1, 0, 0, 1), get_only_relevant_parent=True, time_stamp=None) + parent30001 = cg.get_parent( + to_label(cg, 3, 0, 0, 0, 1), + ) + # parent30002 = 
cg.get_parent(to_label(cg, 3, 0, 0, 0, 2), ) + parent31001 = cg.get_parent( + to_label(cg, 3, 1, 0, 0, 1), + ) - children40001 = cgraph.get_children(to_label(cgraph, 4, 0, 0, 0, 1)) - children40002 = cgraph.get_children(to_label(cgraph, 4, 0, 0, 0, 2)) + children40001 = cg.get_children(to_label(cg, 4, 0, 0, 0, 1)) + children40002 = cg.get_children(to_label(cg, 4, 0, 0, 0, 2)) - parent40001 = cgraph.get_parent(to_label(cgraph, 4, 0, 0, 0, 1), get_only_relevant_parent=True, time_stamp=None) - parent40002 = cgraph.get_parent(to_label(cgraph, 4, 0, 0, 0, 2), get_only_relevant_parent=True, time_stamp=None) + parent40001 = cg.get_parent( + to_label(cg, 4, 0, 0, 0, 1), + ) + parent40002 = cg.get_parent( + to_label(cg, 4, 0, 0, 0, 2), + ) # (non-existing) Children of L1 assert np.array_equal(children10000, []) is True @@ -593,30 +584,39 @@ def test_get_parent_and_children(self, gen_graph_simplequerytest): assert np.array_equal(children12000, []) is True # Parent of L1 - assert parent10000 == to_label(cgraph, 2, 0, 0, 0, 1) - assert parent11000 == to_label(cgraph, 2, 1, 0, 0, 1) - assert parent11001 == to_label(cgraph, 2, 1, 0, 0, 1) - assert parent12000 == to_label(cgraph, 2, 2, 0, 0, 1) + assert parent10000 == to_label(cg, 2, 0, 0, 0, 1) + assert parent11000 == to_label(cg, 2, 1, 0, 0, 1) + assert parent11001 == to_label(cg, 2, 1, 0, 0, 1) + assert parent12000 == to_label(cg, 2, 2, 0, 0, 1) # Children of L2 - assert len(children20001) == 1 and to_label(cgraph, 1, 0, 0, 0, 0) in children20001 - assert len(children21001) == 2 and to_label(cgraph, 1, 1, 0, 0, 0) in children21001 and to_label(cgraph, 1, 1, 0, 0, 1) in children21001 - assert len(children22001) == 1 and to_label(cgraph, 1, 2, 0, 0, 0) in children22001 + assert len(children20001) == 1 and to_label(cg, 1, 0, 0, 0, 0) in children20001 + assert ( + len(children21001) == 2 + and to_label(cg, 1, 1, 0, 0, 0) in children21001 + and to_label(cg, 1, 1, 0, 0, 1) in children21001 + ) + assert len(children22001) == 1 and to_label(cg, 1, 2, 0, 0, 0) in children22001 # Parent of L2 - assert parent20001 == to_label(cgraph, 4, 0, 0, 0, 1) - assert parent21001 == to_label(cgraph, 3, 0, 0, 0, 1) - assert parent22001 == to_label(cgraph, 3, 1, 0, 0, 1) + assert parent20001 == to_label(cg, 4, 0, 0, 0, 1) + assert parent21001 == to_label(cg, 3, 0, 0, 0, 1) + assert parent22001 == to_label(cg, 3, 1, 0, 0, 1) # Children of L3 assert len(children30001) == 1 and len(children31001) == 1 - assert to_label(cgraph, 2, 1, 0, 0, 1) in children30001 - assert to_label(cgraph, 2, 2, 0, 0, 1) in children31001 + assert to_label(cg, 2, 1, 0, 0, 1) in children30001 + assert to_label(cg, 2, 2, 0, 0, 1) in children31001 # Parent of L3 assert parent30001 == parent31001 - assert (parent30001 == to_label(cgraph, 4, 0, 0, 0, 1) and parent20001 == to_label(cgraph, 4, 0, 0, 0, 2)) or \ - (parent30001 == to_label(cgraph, 4, 0, 0, 0, 2) and parent20001 == to_label(cgraph, 4, 0, 0, 0, 1)) + assert ( + parent30001 == to_label(cg, 4, 0, 0, 0, 1) + and parent20001 == to_label(cg, 4, 0, 0, 0, 2) + ) or ( + parent30001 == to_label(cg, 4, 0, 0, 0, 2) + and parent20001 == to_label(cg, 4, 0, 0, 0, 1) + ) # Children of L4 assert parent10000 in children40001 @@ -626,167 +626,137 @@ def test_get_parent_and_children(self, gen_graph_simplequerytest): assert parent40001 is None assert parent40002 is None - # # Children of (non-existing) L5 - # with pytest.raises(IndexError): - # cgraph.get_children(to_label(cgraph, 5, 0, 0, 0, 1)) - - # # Parent of (non-existing) L5 - # with 
pytest.raises(IndexError): - # cgraph.get_parent(to_label(cgraph, 5, 0, 0, 0, 1), get_only_relevant_parent=True, time_stamp=None) - - children2_separate = cgraph.get_children([to_label(cgraph, 2, 0, 0, 0, 1), - to_label(cgraph, 2, 1, 0, 0, 1), - to_label(cgraph, 2, 2, 0, 0, 1)]) + children2_separate = cg.get_children( + [ + to_label(cg, 2, 0, 0, 0, 1), + to_label(cg, 2, 1, 0, 0, 1), + to_label(cg, 2, 2, 0, 0, 1), + ] + ) assert len(children2_separate) == 3 - assert to_label(cgraph, 2, 0, 0, 0, 1) in children2_separate and \ - np.all(np.isin(children2_separate[to_label(cgraph, 2, 0, 0, 0, 1)], children20001)) - assert to_label(cgraph, 2, 1, 0, 0, 1) in children2_separate and \ - np.all(np.isin(children2_separate[to_label(cgraph, 2, 1, 0, 0, 1)], children21001)) - assert to_label(cgraph, 2, 2, 0, 0, 1) in children2_separate and \ - np.all(np.isin(children2_separate[to_label(cgraph, 2, 2, 0, 0, 1)], children22001)) - - children2_combined = cgraph.get_children([to_label(cgraph, 2, 0, 0, 0, 1), - to_label(cgraph, 2, 1, 0, 0, 1), - to_label(cgraph, 2, 2, 0, 0, 1)], flatten=True) - assert len(children2_combined) == 4 and \ - np.all(np.isin(children20001, children2_combined)) and \ - np.all(np.isin(children21001, children2_combined)) and \ - np.all(np.isin(children22001, children2_combined)) + assert to_label(cg, 2, 0, 0, 0, 1) in children2_separate and np.all( + np.isin(children2_separate[to_label(cg, 2, 0, 0, 0, 1)], children20001) + ) + assert to_label(cg, 2, 1, 0, 0, 1) in children2_separate and np.all( + np.isin(children2_separate[to_label(cg, 2, 1, 0, 0, 1)], children21001) + ) + assert to_label(cg, 2, 2, 0, 0, 1) in children2_separate and np.all( + np.isin(children2_separate[to_label(cg, 2, 2, 0, 0, 1)], children22001) + ) + + children2_combined = cg.get_children( + [ + to_label(cg, 2, 0, 0, 0, 1), + to_label(cg, 2, 1, 0, 0, 1), + to_label(cg, 2, 2, 0, 0, 1), + ], + flatten=True, + ) + assert ( + len(children2_combined) == 4 + and np.all(np.isin(children20001, children2_combined)) + and np.all(np.isin(children21001, children2_combined)) + and np.all(np.isin(children22001, children2_combined)) + ) @pytest.mark.timeout(30) def test_get_root(self, gen_graph_simplequerytest): - cgraph = gen_graph_simplequerytest - - root10000 = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), - time_stamp=None) - - root11000 = cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0), - time_stamp=None) - - root11001 = cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 1), - time_stamp=None) - - root12000 = cgraph.get_root(to_label(cgraph, 1, 2, 0, 0, 0), - time_stamp=None) + cg = gen_graph_simplequerytest + root10000 = cg.get_root( + to_label(cg, 1, 0, 0, 0, 0), + ) + root11000 = cg.get_root( + to_label(cg, 1, 1, 0, 0, 0), + ) + root11001 = cg.get_root( + to_label(cg, 1, 1, 0, 0, 1), + ) + root12000 = cg.get_root( + to_label(cg, 1, 2, 0, 0, 0), + ) - with pytest.raises(Exception) as e: - cgraph.get_root(0) + with pytest.raises(Exception): + cg.get_root(0) - assert (root10000 == to_label(cgraph, 4, 0, 0, 0, 1) and - root11000 == root11001 == root12000 == to_label( - cgraph, 4, 0, 0, 0, 2)) or \ - (root10000 == to_label(cgraph, 4, 0, 0, 0, 2) and - root11000 == root11001 == root12000 == to_label( - cgraph, 4, 0, 0, 0, 1)) + assert ( + root10000 == to_label(cg, 4, 0, 0, 0, 1) + and root11000 == root11001 == root12000 == to_label(cg, 4, 0, 0, 0, 2) + ) or ( + root10000 == to_label(cg, 4, 0, 0, 0, 2) + and root11000 == root11001 == root12000 == to_label(cg, 4, 0, 0, 0, 1) + ) @pytest.mark.timeout(30) def 
test_get_subgraph_nodes(self, gen_graph_simplequerytest): - cgraph = gen_graph_simplequerytest - root1 = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) - root2 = cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) + cg = gen_graph_simplequerytest + root1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + root2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) - lvl1_nodes_1 = cgraph.get_subgraph_nodes(root1) - lvl1_nodes_2 = cgraph.get_subgraph_nodes(root2) + lvl1_nodes_1 = cg.get_subgraph([root1], leaves_only=True) + lvl1_nodes_2 = cg.get_subgraph([root2], leaves_only=True) assert len(lvl1_nodes_1) == 1 assert len(lvl1_nodes_2) == 3 - assert to_label(cgraph, 1, 0, 0, 0, 0) in lvl1_nodes_1 - assert to_label(cgraph, 1, 1, 0, 0, 0) in lvl1_nodes_2 - assert to_label(cgraph, 1, 1, 0, 0, 1) in lvl1_nodes_2 - assert to_label(cgraph, 1, 2, 0, 0, 0) in lvl1_nodes_2 - - lvl2_nodes_1 = cgraph.get_subgraph_nodes(root1, return_layers=[2]) - lvl2_nodes_2 = cgraph.get_subgraph_nodes(root2, return_layers=[2]) - assert len(lvl2_nodes_1) == 1 - assert len(lvl2_nodes_2) == 2 - assert to_label(cgraph, 2, 0, 0, 0, 1) in lvl2_nodes_1 - assert to_label(cgraph, 2, 1, 0, 0, 1) in lvl2_nodes_2 - assert to_label(cgraph, 2, 2, 0, 0, 1) in lvl2_nodes_2 - - lvl3_nodes_1 = cgraph.get_subgraph_nodes(root1, return_layers=[3]) - lvl3_nodes_2 = cgraph.get_subgraph_nodes(root2, return_layers=[3]) - assert len(lvl3_nodes_1) == 1 - assert len(lvl3_nodes_2) == 2 - assert to_label(cgraph, 2, 0, 0, 0, 1) in lvl3_nodes_1 - assert to_label(cgraph, 3, 0, 0, 0, 1) in lvl3_nodes_2 - assert to_label(cgraph, 3, 1, 0, 0, 1) in lvl3_nodes_2 - - lvl4_node = cgraph.get_subgraph_nodes(root1, return_layers=[4]) - assert len(lvl4_node) == 1 - assert root1 in lvl4_node - - layers = cgraph.get_subgraph_nodes(root2, return_layers=[1, 4]) - assert len(layers) == 2 and 1 in layers and 4 in layers - assert len(layers[4]) == 1 and root2 in layers[4] - assert len(layers[1]) == 3 - assert to_label(cgraph, 1, 1, 0, 0, 0) in layers[1] - assert to_label(cgraph, 1, 1, 0, 0, 1) in layers[1] - assert to_label(cgraph, 1, 2, 0, 0, 0) in layers[1] - - lvl2_nodes = cgraph.get_subgraph_nodes(root2, return_layers=[2], - bounding_box=[[1, 0, 0], [2, 1, 1]], - bb_is_coordinate=False) - assert len(lvl2_nodes) == 1 - assert to_label(cgraph, 2, 1, 0, 0, 1) in lvl2_nodes - - lvl2_parent = cgraph.get_parent(to_label(cgraph, 1, 1, 0, 0, 0)) - lvl1_nodes = cgraph.get_subgraph_nodes(lvl2_parent) + assert to_label(cg, 1, 0, 0, 0, 0) in lvl1_nodes_1 + assert to_label(cg, 1, 1, 0, 0, 0) in lvl1_nodes_2 + assert to_label(cg, 1, 1, 0, 0, 1) in lvl1_nodes_2 + assert to_label(cg, 1, 2, 0, 0, 0) in lvl1_nodes_2 + + lvl2_parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) + lvl1_nodes = cg.get_subgraph([lvl2_parent], leaves_only=True) assert len(lvl1_nodes) == 2 - assert to_label(cgraph, 1, 1, 0, 0, 0) in lvl1_nodes - assert to_label(cgraph, 1, 1, 0, 0, 1) in lvl1_nodes + assert to_label(cg, 1, 1, 0, 0, 0) in lvl1_nodes + assert to_label(cg, 1, 1, 0, 0, 1) in lvl1_nodes @pytest.mark.timeout(30) def test_get_subgraph_edges(self, gen_graph_simplequerytest): - cgraph = gen_graph_simplequerytest - root1 = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) - root2 = cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) - - edges, affinities, areas = cgraph.get_subgraph_edges(root1) - assert len(edges) == 0 and len(affinities) == 0 and len(areas) == 0 - - edges, affinities, areas = cgraph.get_subgraph_edges(root2) - - assert [to_label(cgraph, 1, 1, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 1)] in edges or \ - 
[to_label(cgraph, 1, 1, 0, 0, 1), - to_label(cgraph, 1, 1, 0, 0, 0)] in edges - - assert [to_label(cgraph, 1, 1, 0, 0, 0), - to_label(cgraph, 1, 2, 0, 0, 0)] in edges or \ - [to_label(cgraph, 1, 2, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 0)] in edges - - # assert len(edges) == 2 and len(affinities) == 2 and len(areas) == 2 - - lvl2_parent = cgraph.get_parent(to_label(cgraph, 1, 1, 0, 0, 0)) - edges, affinities, areas = cgraph.get_subgraph_edges(lvl2_parent) - assert [to_label(cgraph, 1, 1, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 1)] in edges or \ - [to_label(cgraph, 1, 1, 0, 0, 1), - to_label(cgraph, 1, 1, 0, 0, 0)] in edges - - assert [to_label(cgraph, 1, 1, 0, 0, 0), - to_label(cgraph, 1, 2, 0, 0, 0)] in edges or \ - [to_label(cgraph, 1, 2, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 0)] in edges - - assert len(edges) == 2 + cg = gen_graph_simplequerytest + root1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + root2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) + + edges = cg.get_subgraph([root1], edges_only=True) + assert len(edges) == 0 + + edges = cg.get_subgraph([root2], edges_only=True) + assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)] in edges or [ + to_label(cg, 1, 1, 0, 0, 1), + to_label(cg, 1, 1, 0, 0, 0), + ] in edges + + assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0)] in edges or [ + to_label(cg, 1, 2, 0, 0, 0), + to_label(cg, 1, 1, 0, 0, 0), + ] in edges + + lvl2_parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) + edges = cg.get_subgraph([lvl2_parent], edges_only=True) + assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)] in edges or [ + to_label(cg, 1, 1, 0, 0, 1), + to_label(cg, 1, 1, 0, 0, 0), + ] in edges + + assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0)] in edges or [ + to_label(cg, 1, 2, 0, 0, 0), + to_label(cg, 1, 1, 0, 0, 0), + ] in edges + + assert len(edges) == 1 @pytest.mark.timeout(30) def test_get_subgraph_nodes_bb(self, gen_graph_simplequerytest): - cgraph = gen_graph_simplequerytest - - bb = np.array([[1, 0, 0], [2, 1, 1]], dtype=np.int) - bb_coord = bb * cgraph.chunk_size - - childs_1 = cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 1)), bounding_box=bb) - childs_2 = cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 1)), bounding_box=bb_coord, bb_is_coordinate=True) - + cg = gen_graph_simplequerytest + bb = np.array([[1, 0, 0], [2, 1, 1]], dtype=int) + bb_coord = bb * cg.meta.graph_config.CHUNK_SIZE + childs_1 = cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 1, 0, 0, 1))], bbox=bb, leaves_only=True + ) + childs_2 = cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 1, 0, 0, 1))], + bbox=bb_coord, + bbox_is_coordinate=True, + leaves_only=True, + ) assert np.all(~(np.sort(childs_1) - np.sort(childs_2))) - @pytest.mark.timeout(30) - def test_get_atomic_partners(self, gen_graph_simplequerytest): - cgraph = gen_graph_simplequerytest - class TestGraphMerge: @pytest.mark.timeout(30) @@ -801,32 +771,35 @@ def test_merge_pair_same_chunk(self, gen_graph): └─────┘ └─────┘ """ - cgraph = gen_graph(n_layers=2) + atomic_chunk_bounds = np.array([1, 1, 1]) + cg = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[], + 
timestamp=fake_timestamp, + ) # Merge - new_root_ids = cgraph.add_edges("Jane Doe", [to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 0)], affinities=0.3).new_root_ids + new_root_ids = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 0)], + affinities=[0.3], + ).new_root_ids assert len(new_root_ids) == 1 new_root_id = new_root_ids[0] # Check - assert cgraph.get_parent(to_label(cgraph, 1, 0, 0, 0, 0)) == new_root_id - assert cgraph.get_parent(to_label(cgraph, 1, 0, 0, 0, 1)) == new_root_id - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert partners[0] == to_label(cgraph, 1, 0, 0, 0, 1) and affinities[0] == np.float32(0.3) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 1)) - assert partners[0] == to_label(cgraph, 1, 0, 0, 0, 0) and affinities[0] == np.float32(0.3) - leaves = np.unique(cgraph.get_subgraph_nodes(new_root_id)) + assert cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id + assert cg.get_parent(to_label(cg, 1, 0, 0, 0, 1)) == new_root_id + leaves = np.unique(cg.get_subgraph([new_root_id], leaves_only=True)) assert len(leaves) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 0, 0, 0, 1) in leaves + assert to_label(cg, 1, 0, 0, 0, 0) in leaves + assert to_label(cg, 1, 0, 0, 0, 1) in leaves @pytest.mark.timeout(30) def test_merge_pair_neighboring_chunks(self, gen_graph): @@ -839,42 +812,52 @@ def test_merge_pair_neighboring_chunks(self, gen_graph): └─────┴─────┘ └─────┴─────┘ """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) # Merge - new_root_ids = cgraph.add_edges("Jane Doe", [to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0)], affinities=0.3).new_root_ids + new_root_ids = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], + affinities=0.3, + ).new_root_ids assert len(new_root_ids) == 1 new_root_id = new_root_ids[0] # Check - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) == new_root_id - assert cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) == new_root_id - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert partners[0] == to_label(cgraph, 1, 1, 0, 0, 0) and affinities[0] == np.float32(0.3) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 1, 0, 0, 0)) - assert partners[0] == to_label(cgraph, 1, 0, 0, 0, 0) and affinities[0] == np.float32(0.3) - leaves = np.unique(cgraph.get_subgraph_nodes(new_root_id)) + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id + assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == new_root_id + leaves = np.unique(cg.get_subgraph([new_root_id], leaves_only=True)) assert len(leaves) == 2 - assert 
to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 1, 0, 0, 0) in leaves + assert to_label(cg, 1, 0, 0, 0, 0) in leaves + assert to_label(cg, 1, 1, 0, 0, 0) in leaves - @pytest.mark.timeout(30) + @pytest.mark.timeout(120) def test_merge_pair_disconnected_chunks(self, gen_graph): """ Add edge between existing RG supervoxels 1 and 2 (disconnected chunks) @@ -885,41 +868,64 @@ def test_merge_pair_disconnected_chunks(self, gen_graph): └─────┘ └─────┘ └─────┘ └─────┘ """ - cgraph = gen_graph(n_layers=9) + cg = gen_graph(n_layers=5) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk Z - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 127, 127, 127, 0)], - edges=[], - timestamp=fake_timestamp) - - cgraph.add_layer(3, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(3, np.array([[0x7F, 0x7F, 0x7F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(4, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(4, np.array([[0x3F, 0x3F, 0x3F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(5, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(5, np.array([[0x1F, 0x1F, 0x1F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(6, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(6, np.array([[0x0F, 0x0F, 0x0F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(7, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(7, np.array([[0x07, 0x07, 0x07]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(8, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(8, np.array([[0x03, 0x03, 0x03]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(9, np.array([[0x00, 0x00, 0x00], [0x01, 0x01, 0x01]]), time_stamp=fake_timestamp, n_threads=1) + create_chunk( + cg, + vertices=[to_label(cg, 1, 7, 7, 7, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 3, + [3, 3, 3], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 4, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 5, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) # Merge - result = cgraph.add_edges("Jane Doe", [to_label(cgraph, 1, 127, 127, 127, 0), to_label(cgraph, 1, 0, 0, 0, 0)], affinities=0.3) + result = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 7, 7, 7, 0), to_label(cg, 1, 0, 0, 0, 0)], + affinities=[0.3], + ) new_root_ids, lvl2_node_ids = result.new_root_ids, result.new_lvl2_ids print(f"lvl2_node_ids: {lvl2_node_ids}") - u_layers = np.unique(cgraph.get_chunk_layers(lvl2_node_ids)) + u_layers = np.unique(cg.get_chunk_layers(lvl2_node_ids)) assert len(u_layers) == 1 assert u_layers[0] == 2 @@ -927,22 +933,18 @@ def test_merge_pair_disconnected_chunks(self, gen_graph): new_root_id = new_root_ids[0] # Check - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) == new_root_id - assert cgraph.get_root(to_label(cgraph, 1, 127, 127, 127, 0)) == new_root_id - partners, affinities, areas = 
cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert partners[0] == to_label(cgraph, 1, 127, 127, 127, 0) and affinities[0] == np.float32(0.3) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 127, 127, 127, 0)) - assert partners[0] == to_label(cgraph, 1, 0, 0, 0, 0) and affinities[0] == np.float32(0.3) - leaves = np.unique(cgraph.get_subgraph_nodes(new_root_id)) + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id + assert cg.get_root(to_label(cg, 1, 7, 7, 7, 0)) == new_root_id + leaves = np.unique(cg.get_subgraph(new_root_id, leaves_only=True)) assert len(leaves) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 127, 127, 127, 0) in leaves + assert to_label(cg, 1, 0, 0, 0, 0) in leaves + assert to_label(cg, 1, 7, 7, 7, 0) in leaves @pytest.mark.timeout(30) def test_merge_pair_already_connected(self, gen_graph): """ Add edge between already connected RG supervoxels 1 and 2 (same chunk). - Expected: No change, i.e. same parent (to_label(cgraph, 2, 0, 0, 0, 1)), affinity (0.5) and timestamp as before + Expected: No change, i.e. same parent (to_label(cg, 2, 0, 0, 0, 1)), affinity (0.5) and timestamp as before ┌─────┐ ┌─────┐ │ A¹ │ │ A¹ │ │ 1━2 │ => │ 1━2 │ @@ -950,27 +952,35 @@ def test_merge_pair_already_connected(self, gen_graph): └─────┘ └─────┘ """ - cgraph = gen_graph(n_layers=2) + cg = gen_graph(n_layers=2) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5)], + timestamp=fake_timestamp, + ) - res_old = cgraph.table.read_rows() + res_old = cg.client._table.read_rows() res_old.consume_all() # Merge - cgraph.add_edges("Jane Doe", [to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 0)]) - res_new = cgraph.table.read_rows() + with pytest.raises(Exception): + cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 0)], + ) + res_new = cg.client._table.read_rows() res_new.consume_all() # Check if res_old.rows != res_new.rows: - warn("Rows were modified when merging a pair of already connected supervoxels. " - "While probably not an error, it is an unnecessary operation.") + warn( + "Rows were modified when merging a pair of already connected supervoxels. " + "While probably not an error, it is an unnecessary operation." 
+ ) @pytest.mark.timeout(30) def test_merge_triple_chain_to_full_circle_same_chunk(self, gen_graph): @@ -983,43 +993,31 @@ def test_merge_triple_chain_to_full_circle_same_chunk(self, gen_graph): └─────┘ └─────┘ """ - cgraph = gen_graph(n_layers=2) + cg = gen_graph(n_layers=2) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 2)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 2), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 2), 0.5)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[ + to_label(cg, 1, 0, 0, 0, 0), + to_label(cg, 1, 0, 0, 0, 1), + to_label(cg, 1, 0, 0, 0, 2), + ], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 2), 0.5), + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2), 0.5), + ], + timestamp=fake_timestamp, + ) # Merge - new_root_ids = cgraph.add_edges("Jane Doe", [to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 0)], affinities=0.3).new_root_ids - - assert len(new_root_ids) == 1 - new_root_id = new_root_ids[0] - - # Check - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) == new_root_id - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) == new_root_id - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 2)) == new_root_id - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 1) in partners - assert to_label(cgraph, 1, 0, 0, 0, 2) in partners - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 1)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in partners - assert to_label(cgraph, 1, 0, 0, 0, 2) in partners - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 2)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in partners - assert to_label(cgraph, 1, 0, 0, 0, 1) in partners - leaves = np.unique(cgraph.get_subgraph_nodes(new_root_id)) - assert len(leaves) == 3 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 0, 0, 0, 1) in leaves - assert to_label(cgraph, 1, 0, 0, 0, 2) in leaves + with pytest.raises(Exception): + cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 0)], + affinities=0.3, + ).new_root_ids @pytest.mark.timeout(30) def test_merge_triple_chain_to_full_circle_neighboring_chunks(self, gen_graph): @@ -1032,63 +1030,45 @@ def test_merge_triple_chain_to_full_circle_neighboring_chunks(self, gen_graph): └─────┴─────┘ └─────┴─────┘ """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 1, 0, 0, 0), inf)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 
1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), inf)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), inf)], + timestamp=fake_timestamp, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) # Merge - new_root_ids = cgraph.add_edges("Jane Doe", [to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0)], affinities=1.0).new_root_ids - - assert len(new_root_ids) == 1 - new_root_id = new_root_ids[0] - - # Check - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) == new_root_id - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) == new_root_id - assert cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) == new_root_id - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 1) in partners - assert to_label(cgraph, 1, 1, 0, 0, 0) in partners - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 1)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in partners - assert to_label(cgraph, 1, 1, 0, 0, 0) in partners - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 1, 0, 0, 0)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in partners - assert to_label(cgraph, 1, 0, 0, 0, 1) in partners - leaves = np.unique(cgraph.get_subgraph_nodes(new_root_id)) - assert len(leaves) == 3 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 0, 0, 0, 1) in leaves - assert to_label(cgraph, 1, 1, 0, 0, 0) in leaves - - cross_edge_dict_layers = graph_tests.root_cross_edge_test(new_root_id, cg=cgraph) # dict: layer -> cross_edge_dict - n_cross_edges_layer = collections.defaultdict(list) - - for child_layer in cross_edge_dict_layers.keys(): - for layer in cross_edge_dict_layers[child_layer].keys(): - n_cross_edges_layer[layer].append(len(cross_edge_dict_layers[child_layer][layer])) - - for layer in n_cross_edges_layer.keys(): - assert len(np.unique(n_cross_edges_layer[layer])) == 1 - - @pytest.mark.timeout(30) + with pytest.raises(Exception): + cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], + affinities=1.0, + ).new_root_ids + + @pytest.mark.timeout(120) def test_merge_triple_chain_to_full_circle_disconnected_chunks(self, gen_graph): """ Add edge between indirectly connected RG supervoxels 1 and 2 (disconnected chunks) @@ -1099,73 +1079,93 @@ def test_merge_triple_chain_to_full_circle_disconnected_chunks(self, gen_graph): └─────┘ └─────┘ └─────┘ └─────┘ """ - cgraph = gen_graph(n_layers=9) + cg = gen_graph(n_layers=5) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 127, 127, 127, 0), inf)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + ( + to_label(cg, 1, 0, 0, 0, 1), + to_label(cg, 1, 7, 7, 7, 0), + inf, + ), + ], + 
timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 127, 127, 127, 0)], - edges=[(to_label(cgraph, 1, 127, 127, 127, 0), to_label(cgraph, 1, 0, 0, 0, 1), inf)], - timestamp=fake_timestamp) - - cgraph.add_layer(3, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(3, np.array([[0x7F, 0x7F, 0x7F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(4, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(4, np.array([[0x3F, 0x3F, 0x3F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(5, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(5, np.array([[0x1F, 0x1F, 0x1F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(6, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(6, np.array([[0x0F, 0x0F, 0x0F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(7, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(7, np.array([[0x07, 0x07, 0x07]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(8, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(8, np.array([[0x03, 0x03, 0x03]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(9, np.array([[0x00, 0x00, 0x00], [0x01, 0x01, 0x01]]), time_stamp=fake_timestamp, n_threads=1) + create_chunk( + cg, + vertices=[to_label(cg, 1, 7, 7, 7, 0)], + edges=[ + ( + to_label(cg, 1, 7, 7, 7, 0), + to_label(cg, 1, 0, 0, 0, 1), + inf, + ) + ], + timestamp=fake_timestamp, + ) + + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 3, + [3, 3, 3], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 4, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 4, + [1, 1, 1], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 5, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) # Merge - new_root_ids = cgraph.add_edges("Jane Doe", [to_label(cgraph, 1, 127, 127, 127, 0), to_label(cgraph, 1, 0, 0, 0, 0)], affinities=1.0).new_root_ids + new_root_ids = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 7, 7, 7, 0), to_label(cg, 1, 0, 0, 0, 0)], + affinities=1.0, + ).new_root_ids assert len(new_root_ids) == 1 new_root_id = new_root_ids[0] # Check - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) == new_root_id - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) == new_root_id - assert cgraph.get_root(to_label(cgraph, 1, 127, 127, 127, 0)) == new_root_id - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 1) in partners - assert to_label(cgraph, 1, 127, 127, 127, 0) in partners - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 1)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in partners - assert to_label(cgraph, 1, 127, 127, 127, 0) in partners - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 127, 127, 127, 0)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in partners - assert to_label(cgraph, 1, 0, 0, 0, 1) in partners - leaves = np.unique(cgraph.get_subgraph_nodes(new_root_id)) + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id + assert 
cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_id + assert cg.get_root(to_label(cg, 1, 7, 7, 7, 0)) == new_root_id + leaves = np.unique(cg.get_subgraph(new_root_id, leaves_only=True)) assert len(leaves) == 3 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 0, 0, 0, 1) in leaves - assert to_label(cgraph, 1, 127, 127, 127, 0) in leaves - - cross_edge_dict_layers = graph_tests.root_cross_edge_test(new_root_id, cg=cgraph) # dict: layer -> cross_edge_dict - n_cross_edges_layer = collections.defaultdict(list) - - for child_layer in cross_edge_dict_layers.keys(): - for layer in cross_edge_dict_layers[child_layer].keys(): - n_cross_edges_layer[layer].append(len(cross_edge_dict_layers[child_layer][layer])) - - for layer in n_cross_edges_layer.keys(): - assert len(np.unique(n_cross_edges_layer[layer])) == 1 + assert to_label(cg, 1, 0, 0, 0, 0) in leaves + assert to_label(cg, 1, 0, 0, 0, 1) in leaves + assert to_label(cg, 1, 7, 7, 7, 0) in leaves @pytest.mark.timeout(30) def test_merge_same_node(self, gen_graph): @@ -1178,23 +1178,28 @@ def test_merge_same_node(self, gen_graph): └─────┘ """ - cgraph = gen_graph(n_layers=2) + cg = gen_graph(n_layers=2) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) - res_old = cgraph.table.read_rows() + res_old = cg.client._table.read_rows() res_old.consume_all() # Merge - with pytest.raises(cg_exceptions.PreconditionError): - cgraph.add_edges("Jane Doe", [to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0)]) + with pytest.raises(Exception): + cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], + ) - res_new = cgraph.table.read_rows() + res_new = cg.client._table.read_rows() res_new.consume_all() assert res_new.rows == res_old.rows @@ -1215,31 +1220,44 @@ def test_merge_pair_abstract_nodes(self, gen_graph): └─────┘ """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) - res_old = cgraph.table.read_rows() + res_old = cg.client._table.read_rows() res_old.consume_all() # Merge - with pytest.raises(cg_exceptions.PreconditionError): - cgraph.add_edges("Jane Doe", [to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 2, 1, 0, 0, 1)]) + with pytest.raises(Exception): + cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 2, 1, 0, 0, 1)], + ) - res_new = cgraph.table.read_rows() + res_new = cg.client._table.read_rows() res_new.consume_all() assert res_new.rows == res_old.rows @@ -1260,1031 +1278,284 @@ def test_diagonal_connections(self, gen_graph): └─────┴─────┘ """ - cgraph = 
gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Chunk A - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 0), inf), - (to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 0, 1, 0, 0), inf)]) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf), + ], + ) # Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), - to_label(cgraph, 1, 0, 0, 0, 0), inf)]) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], + ) # Chunk C - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 1, 0, 0)], - edges=[(to_label(cgraph, 1, 0, 1, 0, 0), - to_label(cgraph, 1, 1, 1, 0, 0), inf), - (to_label(cgraph, 1, 0, 1, 0, 0), - to_label(cgraph, 1, 0, 0, 0, 0), inf)]) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 1, 0, 0)], + edges=[ + (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 1, 1, 0, 0), inf), + (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf), + ], + ) # Chunk D - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 1, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 1, 0, 0), - to_label(cgraph, 1, 0, 1, 0, 0), inf)]) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 1, 0, 0)], + edges=[(to_label(cg, 1, 1, 1, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf)], + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]), n_threads=1) + add_layer( + cg, + 3, + [0, 0, 0], + n_threads=1, + ) - rr = cgraph.range_read_chunk( - chunk_id=cgraph.get_chunk_id(layer=3, x=0, y=0, z=0)) + rr = cg.range_read_chunk(chunk_id=cg.get_chunk_id(layer=3, x=0, y=0, z=0)) root_ids_t0 = list(rr.keys()) assert len(root_ids_t0) == 2 child_ids = [] for root_id in root_ids_t0: - cgraph.logger.debug(("root_id", root_id)) - child_ids.extend(cgraph.get_subgraph_nodes(root_id)) + child_ids.extend(cg.get_subgraph(root_id, leaves_only=True)) - new_roots = cgraph.add_edges("Jane Doe", - [to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 0, 0, 0, 1)], - affinities=[.5]).new_root_ids + new_roots = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + affinities=[0.5], + ).new_root_ids root_ids = [] for child_id in child_ids: - root_ids.append(cgraph.get_root(child_id)) + root_ids.append(cg.get_root(child_id)) assert len(np.unique(root_ids)) == 1 root_id = root_ids[0] assert root_id == new_roots[0] - cross_edge_dict_layers = graph_tests.root_cross_edge_test(root_id, - cg=cgraph) # dict: layer -> cross_edge_dict - n_cross_edges_layer = collections.defaultdict(list) - - for child_layer in cross_edge_dict_layers.keys(): - for layer in cross_edge_dict_layers[child_layer].keys(): - n_cross_edges_layer[layer].append( - len(cross_edge_dict_layers[child_layer][layer])) - - for layer in n_cross_edges_layer.keys(): - assert len(np.unique(n_cross_edges_layer[layer])) == 1 - - @pytest.mark.timeout(30) + @pytest.mark.timeout(240) def test_cross_edges(self, gen_graph): """""" - cgraph = gen_graph(n_layers=6) - - chunk_offset = 6 + cg = gen_graph(n_layers=5) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, chunk_offset, 0, 0, 0), 
to_label(cgraph, 1, chunk_offset, 0, 0, 1), - to_label(cgraph, 1, chunk_offset, 0, 0, 2), to_label(cgraph, 1, chunk_offset, 0, 0, 3)], - edges=[(to_label(cgraph, 1, chunk_offset, 0, 0, 0), to_label(cgraph, 1, chunk_offset+1, 0, 0, 0), inf), - (to_label(cgraph, 1, chunk_offset, 0, 0, 1), to_label(cgraph, 1, chunk_offset+1, 0, 0, 1), inf), - (to_label(cgraph, 1, chunk_offset, 0, 0, 0), to_label(cgraph, 1, chunk_offset, 0, 0, 2), .5)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[ + to_label(cg, 1, 0, 0, 0, 0), + to_label(cg, 1, 0, 0, 0, 1), + ], + edges=[ + ( + to_label(cg, 1, 0, 0, 0, 1), + to_label(cg, 1, 1, 0, 0, 0), + inf, + ), + ( + to_label(cg, 1, 0, 0, 0, 0), + to_label(cg, 1, 0, 0, 0, 1), + inf, + ), + ], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, chunk_offset+1, 0, 0, 0), to_label(cgraph, 1, chunk_offset+1, 0, 0, 1)], - edges=[(to_label(cgraph, 1, chunk_offset+1, 0, 0, 0), to_label(cgraph, 1, chunk_offset, 0, 0, 0), inf), - (to_label(cgraph, 1, chunk_offset+1, 0, 0, 1), to_label(cgraph, 1, chunk_offset, 0, 0, 1), inf), - (to_label(cgraph, 1, chunk_offset+1, 0, 0, 0), to_label(cgraph, 1, chunk_offset+2, 0, 0, 0), inf), - (to_label(cgraph, 1, chunk_offset+1, 0, 0, 1), to_label(cgraph, 1, chunk_offset+2, 0, 0, 1), inf)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[ + to_label(cg, 1, 1, 0, 0, 0), + to_label(cg, 1, 1, 0, 0, 1), + ], + edges=[ + ( + to_label(cg, 1, 1, 0, 0, 0), + to_label(cg, 1, 0, 0, 0, 1), + inf, + ), + ( + to_label(cg, 1, 1, 0, 0, 0), + to_label(cg, 1, 1, 0, 0, 1), + inf, + ), + ], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk C - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, chunk_offset+2, 0, 0, 0), to_label(cgraph, 1, chunk_offset+2, 0, 0, 1)], - edges=[(to_label(cgraph, 1, chunk_offset+2, 0, 0, 0), to_label(cgraph, 1, chunk_offset+1, 0, 0, 0), inf), - (to_label(cgraph, 1, chunk_offset+2, 0, 0, 1), to_label(cgraph, 1, chunk_offset+1, 0, 0, 1), inf), - (to_label(cgraph, 1, chunk_offset+2, 0, 0, 0), to_label(cgraph, 1, chunk_offset+3, 0, 0, 0), inf), - (to_label(cgraph, 1, chunk_offset+2, 0, 0, 0), to_label(cgraph, 1, chunk_offset+2, 0, 0, 1), .5)], - timestamp=fake_timestamp) - - # Preparation: Build Chunk D - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, chunk_offset+3, 0, 0, 0)], - edges=[(to_label(cgraph, 1, chunk_offset+3, 0, 0, 0), to_label(cgraph, 1, chunk_offset+2, 0, 0, 0), inf), - (to_label(cgraph, 1, chunk_offset+3, 0, 0, 0), to_label(cgraph, 1, chunk_offset+4, 0, 0, 0), inf)], - timestamp=fake_timestamp) - - # Preparation: Build Chunk E - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, chunk_offset+4, 0, 0, 0)], - edges=[(to_label(cgraph, 1, chunk_offset+4, 0, 0, 0), to_label(cgraph, 1, chunk_offset+3, 0, 0, 0), inf), - (to_label(cgraph, 1, chunk_offset+4, 0, 0, 0), to_label(cgraph, 1, chunk_offset+5, 0, 0, 0), inf)], - timestamp=fake_timestamp) - - # Preparation: Build Chunk F - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, chunk_offset+5, 0, 0, 0)], - edges=[(to_label(cgraph, 1, chunk_offset+5, 0, 0, 0), to_label(cgraph, 1, chunk_offset+4, 0, 0, 0), inf)], - timestamp=fake_timestamp) - - - for i_layer in range(3, 7): - for i_chunk in range(0, 2 ** (7 - i_layer), 2): - cgraph.add_layer(i_layer, np.array([[i_chunk, 0, 0], [i_chunk+1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - - new_roots = cgraph.add_edges("Jane Doe", - [to_label(cgraph, 1, chunk_offset, 0, 0, 0), - to_label(cgraph, 1, chunk_offset, 0, 
0, 3)], - affinities=.9).new_root_ids - - assert len(new_roots) == 1 - root_id = new_roots[0] - - cross_edge_dict_layers = graph_tests.root_cross_edge_test(root_id, cg=cgraph) # dict: layer -> cross_edge_dict - n_cross_edges_layer = collections.defaultdict(list) + create_chunk( + cg, + vertices=[ + to_label(cg, 1, 2, 0, 0, 0), + ], + timestamp=fake_timestamp, + ) - for child_layer in cross_edge_dict_layers.keys(): - for layer in cross_edge_dict_layers[child_layer].keys(): - if layer < child_layer: - continue + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 3, + [1, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 4, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_layer( + cg, + 5, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) - n_cross_edges_layer[layer].append( - len(cross_edge_dict_layers[child_layer][layer])) + new_roots = cg.add_edges( + "Jane Doe", + [ + to_label(cg, 1, 1, 0, 0, 0), + to_label(cg, 1, 2, 0, 0, 0), + ], + affinities=0.9, + ).new_root_ids - for layer in n_cross_edges_layer.keys(): - cgraph.logger.debug("LAYER %d" % layer) - assert len(np.unique(n_cross_edges_layer[layer])) == 1 + assert len(new_roots) == 1 -class TestGraphSplit: - @pytest.mark.timeout(30) - def test_split_pair_same_chunk(self, gen_graph): +class TestGraphMergeSplit: + @pytest.mark.timeout(240) + def test_multiple_cuts_and_splits(self, gen_graph_simplequerytest): """ - Remove edge between existing RG supervoxels 1 and 2 (same chunk) - Expected: Different (new) parents for RG 1 and 2 on Layer two - ┌─────┐ ┌─────┐ - │ A¹ │ │ A¹ │ - │ 1━2 │ => │ 1 2 │ - │ │ │ │ - └─────┘ └─────┘ + ┌─────┬─────┬─────┐ L X Y Z S L X Y Z S L X Y Z S L X Y Z S + │ A¹ │ B¹ │ C¹ │ 1: 1 0 0 0 0 ─── 2 0 0 0 1 ───────────────── 4 0 0 0 1 + │ 1 │ 3━2━┿━━4 │ 2: 1 1 0 0 0 ─┬─ 2 1 0 0 1 ─── 3 0 0 0 1 ─┬─ 4 0 0 0 2 + │ │ │ │ 3: 1 1 0 0 1 ─┘ │ + └─────┴─────┴─────┘ 4: 1 2 0 0 0 ─── 2 2 0 0 1 ─── 3 1 0 0 1 ─┘ """ + cg = gen_graph_simplequerytest - cgraph = gen_graph(n_layers=2) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5)], - timestamp=fake_timestamp) + rr = cg.range_read_chunk(chunk_id=cg.get_chunk_id(layer=4, x=0, y=0, z=0)) + root_ids_t0 = list(rr.keys()) + child_ids = [types.empty_1d] + for root_id in root_ids_t0: + child_ids.append(cg.get_subgraph([root_id], leaves_only=True)) + child_ids = np.concatenate(child_ids) - # Split - new_root_ids = cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 0), mincut=False).new_root_ids + for i in range(10): - # Check New State - assert len(new_root_ids) == 2 - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) != cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert len(partners) == 0 - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 1)) - assert len(partners) == 0 - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)))) - assert len(leaves) == 1 and to_label(cgraph, 1, 0, 0, 0, 0) in leaves - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)))) - assert len(leaves) == 1 and to_label(cgraph, 1, 
0, 0, 0, 1) in leaves - - # Check Old State still accessible - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp) == \ - cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1), time_stamp=fake_timestamp) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 0, 0, 0, 1) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 1), time_stamp=fake_timestamp) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 0, 0, 0, 0) - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp))) - assert len(leaves) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 0, 0, 0, 1) in leaves + print(f"\n\nITERATION {i}/10") + print("\n\nMERGE 1 & 3\n\n") + new_roots = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], + affinities=0.9, + ).new_root_ids + assert len(new_roots) == 1 + assert len(cg.get_subgraph([new_roots[0]], leaves_only=True)) == 4 - # assert len(cgraph.get_latest_roots()) == 2 - # assert len(cgraph.get_latest_roots(fake_timestamp)) == 1 + root_ids = [] + for child_id in child_ids: + root_ids.append(cg.get_root(child_id)) - def test_split_nonexisting_edge(self, gen_graph): - """ - Remove edge between existing RG supervoxels 1 and 2 (same chunk) - Expected: Different (new) parents for RG 1 and 2 on Layer two - ┌─────┐ ┌─────┐ - │ A¹ │ │ A¹ │ - │ 1━2 │ => │ 1━2 │ - │ | │ │ | │ - │ 3 │ │ 3 │ - └─────┘ └─────┘ - """ + u_root_ids = np.unique(root_ids) + assert len(u_root_ids) == 1 - cgraph = gen_graph(n_layers=2) + # ------------------------------------------------------------------ + new_roots = cg.remove_edges( + "John Doe", + source_ids=to_label(cg, 1, 1, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 1), + mincut=False, + ).new_root_ids - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 2), to_label(cgraph, 1, 0, 0, 0, 1), 0.5)], - timestamp=fake_timestamp) + assert len(np.unique(new_roots)) == 2 - # Split - new_root_ids = cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 2), mincut=False).new_root_ids + root_ids = [] + for child_id in child_ids: + root_ids.append(cg.get_root(child_id)) - assert len(new_root_ids) == 1 - assert len(cgraph.get_atomic_node_partners(to_label(cgraph, 1, 0, 0, 0, 0))) == 1 + u_root_ids = np.unique(root_ids) + these_child_ids = [] + for root_id in u_root_ids: + these_child_ids.extend(cg.get_subgraph([root_id], leaves_only=True)) + assert len(these_child_ids) == 4 + assert len(u_root_ids) == 2 - @pytest.mark.timeout(30) - def test_split_pair_neighboring_chunks(self, gen_graph): - """ - Remove edge between existing RG supervoxels 1 and 2 (neighboring chunks) - ┌─────┬─────┐ ┌─────┬─────┐ - │ A¹ │ B¹ │ │ A¹ │ B¹ │ - │ 1━━┿━━2 │ => │ 1 │ 2 │ - │ │ │ │ │ │ - └─────┴─────┘ └─────┴─────┘ - """ + # ------------------------------------------------------------------ - cgraph = gen_graph(n_layers=3) + new_roots = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 1), + mincut=False, + 
).new_root_ids + assert len(new_roots) == 2 - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), 1.0)], - timestamp=fake_timestamp) + root_ids = [] + for child_id in child_ids: + root_ids.append(cg.get_root(child_id)) - # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), 1.0)], - timestamp=fake_timestamp) + u_root_ids = np.unique(root_ids) + assert len(u_root_ids) == 3 - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) + # ------------------------------------------------------------------ - # Split - new_root_ids = cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), mincut=False).new_root_ids + print(f"\n\nITERATION {i}/10") + print("\n\nMERGE 2 & 3\n\n") - # Check New State - assert len(new_root_ids) == 2 - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) != cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert len(partners) == 0 - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 1, 0, 0, 0)) - assert len(partners) == 0 - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)))) - assert len(leaves) == 1 and to_label(cgraph, 1, 0, 0, 0, 0) in leaves - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)))) - assert len(leaves) == 1 and to_label(cgraph, 1, 1, 0, 0, 0) in leaves - - # Check Old State still accessible - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp) == \ - cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0), time_stamp=fake_timestamp) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 1, 0, 0, 0) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 1, 0, 0, 0), time_stamp=fake_timestamp) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 0, 0, 0, 0) - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp))) - assert len(leaves) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 1, 0, 0, 0) in leaves + new_roots = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], + affinities=0.9, + ).new_root_ids + assert len(new_roots) == 1 - assert len(cgraph.get_latest_roots()) == 2 - assert len(cgraph.get_latest_roots(fake_timestamp)) == 1 + root_ids = [] + for child_id in child_ids: + root_ids.append(cg.get_root(child_id)) - @pytest.mark.timeout(30) - def test_split_verify_cross_chunk_edges(self, gen_graph): - """ - Remove edge between existing RG supervoxels 1 and 2 (neighboring chunks) - ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ - | │ A¹ │ B¹ │ | │ A¹ │ B¹ │ - | │ 1━━┿━━3 │ => | │ 1━━┿━━3 │ - | │ | │ │ | │ │ │ - | │ 2 │ │ | │ 2 │ │ - └─────┴─────┴─────┘ └─────┴─────┴─────┘ - """ - - cgraph = gen_graph(n_layers=4) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 
1, 1, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 2, 0, 0, 0), inf), - (to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 1), .5)], - timestamp=fake_timestamp) - - # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 2, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 2, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), inf)], - timestamp=fake_timestamp) - - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(3, np.array([[2, 0, 0], [3, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(4, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - - assert cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) == cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 1)) - assert cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) == cgraph.get_root(to_label(cgraph, 1, 2, 0, 0, 0)) - - # Split - new_root_ids = cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 1), mincut=False).new_root_ids - - assert len(new_root_ids) == 2 - - svs2 = cgraph.get_subgraph_nodes(new_root_ids[0]) - svs1 = cgraph.get_subgraph_nodes(new_root_ids[1]) - len_set = {1, 2} - assert len(svs1) in len_set - len_set.remove(len(svs1)) - assert len(svs2) in len_set - - - # Check New State - assert len(new_root_ids) == 2 - assert cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) != cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 1)) - assert cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) == cgraph.get_root(to_label(cgraph, 1, 2, 0, 0, 0)) - - cc_dict = cgraph.get_atomic_cross_edge_dict(cgraph.get_parent(to_label(cgraph, 1, 1, 0, 0, 0))) - assert len(cc_dict[3]) == 1 - assert cc_dict[3][0][0] == to_label(cgraph, 1, 1, 0, 0, 0) - assert cc_dict[3][0][1] == to_label(cgraph, 1, 2, 0, 0, 0) - - assert len(cgraph.get_latest_roots()) == 2 - assert len(cgraph.get_latest_roots(fake_timestamp)) == 1 - - @pytest.mark.timeout(30) - def test_split_verify_loop(self, gen_graph): - """ - Remove edge between existing RG supervoxels 1 and 2 (neighboring chunks) - ┌─────┬────────┬─────┐ ┌─────┬────────┬─────┐ - | │ A¹ │ B¹ │ | │ A¹ │ B¹ │ - | │ 4━━1━━┿━━5 │ => | │ 4 1━━┿━━5 │ - | │ / │ | │ | │ │ | │ - | │ 3 2━━┿━━6 │ | │ 3 2━━┿━━6 │ - └─────┴────────┴─────┘ └─────┴────────┴─────┘ - """ - - cgraph = gen_graph(n_layers=4) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 1), - to_label(cgraph, 1, 1, 0, 0, 2), to_label(cgraph, 1, 1, 0, 0, 3)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 2, 0, 0, 0), inf), - (to_label(cgraph, 1, 1, 0, 0, 1), to_label(cgraph, 1, 2, 0, 0, 1), inf), - (to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 2), .5), - (to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 3), .5)], - timestamp=fake_timestamp) - - # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 2, 0, 0, 0), to_label(cgraph, 1, 2, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 2, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), inf), - (to_label(cgraph, 1, 2, 0, 0, 1), to_label(cgraph, 1, 1, 0, 0, 1), inf), - (to_label(cgraph, 1, 2, 0, 0, 1), to_label(cgraph, 1, 2, 0, 0, 0), .5)], - timestamp=fake_timestamp) - - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(3, np.array([[2, 0, 0], [3, 0, 
0]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(4, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - - assert cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) == cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 1)) - assert cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) == cgraph.get_root(to_label(cgraph, 1, 2, 0, 0, 0)) - - # Split - new_root_ids = cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 2), mincut=False).new_root_ids - - assert len(new_root_ids) == 2 - - new_root_ids = cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 3), mincut=False).new_root_ids - - assert len(new_root_ids) == 2 - - cc_dict = cgraph.get_atomic_cross_edge_dict(cgraph.get_parent(to_label(cgraph, 1, 1, 0, 0, 0))) - assert len(cc_dict[3]) == 1 - cc_dict = cgraph.get_atomic_cross_edge_dict(cgraph.get_parent(to_label(cgraph, 1, 1, 0, 0, 0))) - assert len(cc_dict[3]) == 1 - - assert len(cgraph.get_latest_roots()) == 3 - assert len(cgraph.get_latest_roots(fake_timestamp)) == 1 - - @pytest.mark.timeout(30) - def test_split_pair_disconnected_chunks(self, gen_graph): - """ - Remove edge between existing RG supervoxels 1 and 2 (disconnected chunks) - ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ - │ A¹ │ ... │ Z¹ │ │ A¹ │ ... │ Z¹ │ - │ 1━━┿━━━━━┿━━2 │ => │ 1 │ │ 2 │ - │ │ │ │ │ │ │ │ - └─────┘ └─────┘ └─────┘ └─────┘ - """ - - cgraph = gen_graph(n_layers=9) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 127, 127, 127, 0), 1.0)], - timestamp=fake_timestamp) - - # Preparation: Build Chunk Z - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 127, 127, 127, 0)], - edges=[(to_label(cgraph, 1, 127, 127, 127, 0), to_label(cgraph, 1, 0, 0, 0, 0), 1.0)], - timestamp=fake_timestamp) - - cgraph.add_layer(3, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(3, np.array([[0x7F, 0x7F, 0x7F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(4, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(4, np.array([[0x3F, 0x3F, 0x3F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(5, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(5, np.array([[0x1F, 0x1F, 0x1F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(6, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(6, np.array([[0x0F, 0x0F, 0x0F]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(7, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(7, np.array([[0x07, 0x07, 0x07]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(8, np.array([[0x00, 0x00, 0x00]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(8, np.array([[0x03, 0x03, 0x03]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(9, np.array([[0x00, 0x00, 0x00], [0x01, 0x01, 0x01]]), time_stamp=fake_timestamp, n_threads=1) - - # Split - new_roots = cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 127, 127, 127, 0), to_label(cgraph, 1, 0, 0, 0, 0), mincut=False).new_root_ids - - # Check New State - assert len(new_roots) == 2 - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) != cgraph.get_root(to_label(cgraph, 1, 127, 127, 127, 0)) - partners, 
affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert len(partners) == 0 - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 127, 127, 127, 0)) - assert len(partners) == 0 - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)))) - assert len(leaves) == 1 and to_label(cgraph, 1, 0, 0, 0, 0) in leaves - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 127, 127, 127, 0)))) - assert len(leaves) == 1 and to_label(cgraph, 1, 127, 127, 127, 0) in leaves - - # Check Old State still accessible - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp) == \ - cgraph.get_root(to_label(cgraph, 1, 127, 127, 127, 0), time_stamp=fake_timestamp) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 127, 127, 127, 0) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 127, 127, 127, 0), time_stamp=fake_timestamp) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 0, 0, 0, 0) - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp))) - assert len(leaves) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 127, 127, 127, 0) in leaves - - @pytest.mark.timeout(30) - def test_split_pair_already_disconnected(self, gen_graph): - """ - Try to remove edge between already disconnected RG supervoxels 1 and 2 (same chunk). - Expected: No change, no error - ┌─────┐ ┌─────┐ - │ A¹ │ │ A¹ │ - │ 1 2 │ => │ 1 2 │ - │ │ │ │ - └─────┘ └─────┘ - """ - - cgraph = gen_graph(n_layers=2) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp) - - res_old = cgraph.table.read_rows() - res_old.consume_all() - - # Split - with pytest.raises(cg_exceptions.PreconditionError): - cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 0), mincut=False) - - res_new = cgraph.table.read_rows() - res_new.consume_all() - - # Check - if res_old.rows != res_new.rows: - warn("Rows were modified when splitting a pair of already disconnected supervoxels. 
" - "While probably not an error, it is an unnecessary operation.") - - @pytest.mark.timeout(30) - def test_split_full_circle_to_triple_chain_same_chunk(self, gen_graph): - """ - Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection (same chunk) - ┌─────┐ ┌─────┐ - │ A¹ │ │ A¹ │ - │ 1━2 │ => │ 1 2 │ - │ ┗3┛ │ │ ┗3┛ │ - └─────┘ └─────┘ - """ - - cgraph = gen_graph(n_layers=2) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 2)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 2), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 2), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.3)], - timestamp=fake_timestamp) - - # Split - new_root_ids = cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 0), mincut=False).new_root_ids - - # Check New State - assert len(new_root_ids) == 1 - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) == new_root_ids[0] - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) == new_root_ids[0] - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 2)) == new_root_ids[0] - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 0, 0, 0, 2) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 1)) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 0, 0, 0, 2) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 2)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in partners - assert to_label(cgraph, 1, 0, 0, 0, 1) in partners - leaves = np.unique(cgraph.get_subgraph_nodes(new_root_ids[0])) - assert len(leaves) == 3 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 0, 0, 0, 1) in leaves - assert to_label(cgraph, 1, 0, 0, 0, 2) in leaves - - # Check Old State still accessible - old_root_id = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp) - assert new_root_ids[0] != old_root_id - - # assert len(cgraph.get_latest_roots()) == 1 - # assert len(cgraph.get_latest_roots(fake_timestamp)) == 1 - - @pytest.mark.timeout(30) - def test_split_full_circle_to_triple_chain_neighboring_chunks(self, gen_graph): - """ - Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection (neighboring chunks) - ┌─────┬─────┐ ┌─────┬─────┐ - │ A¹ │ B¹ │ │ A¹ │ B¹ │ - │ 1━━┿━━2 │ => │ 1 │ 2 │ - │ ┗3━┿━━┛ │ │ ┗3━┿━━┛ │ - └─────┴─────┘ └─────┴─────┘ - """ - - cgraph = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 1, 0, 0, 0), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), 0.3)], - timestamp=fake_timestamp) - - # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5), - (to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), 0.3)], - 
timestamp=fake_timestamp) - - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - - # Split - new_root_ids = cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), mincut=False).new_root_ids - - # Check New State - assert len(new_root_ids) == 1 - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) == new_root_ids[0] - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) == new_root_ids[0] - assert cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) == new_root_ids[0] - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 0, 0, 0, 1) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 1, 0, 0, 0)) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 0, 0, 0, 1) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 1)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in partners - assert to_label(cgraph, 1, 1, 0, 0, 0) in partners - leaves = np.unique(cgraph.get_subgraph_nodes(new_root_ids[0])) - assert len(leaves) == 3 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 0, 0, 0, 1) in leaves - assert to_label(cgraph, 1, 1, 0, 0, 0) in leaves - - # Check Old State still accessible - old_root_id = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp) - assert new_root_ids[0] != old_root_id - - assert len(cgraph.get_latest_roots()) == 1 - assert len(cgraph.get_latest_roots(fake_timestamp)) == 1 - - @pytest.mark.timeout(30) - def test_split_full_circle_to_triple_chain_disconnected_chunks(self, gen_graph): - """ - Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection (disconnected chunks) - ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ - │ A¹ │ ... │ Z¹ │ │ A¹ │ ... 
│ Z¹ │ - │ 1━━┿━━━━━┿━━2 │ => │ 1 │ │ 2 │ - │ ┗3━┿━━━━━┿━━┛ │ │ ┗3━┿━━━━━┿━━┛ │ - └─────┘ └─────┘ └─────┘ └─────┘ - """ - - cgraph = gen_graph(n_layers=9) - - loc = 2 - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, loc, loc, loc, 0), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, loc, loc, loc, 0), 0.3)], - timestamp=fake_timestamp) - - # Preparation: Build Chunk Z - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, loc, loc, loc, 0)], - edges=[(to_label(cgraph, 1, loc, loc, loc, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5), - (to_label(cgraph, 1, loc, loc, loc, 0), to_label(cgraph, 1, 0, 0, 0, 0), 0.3)], - timestamp=fake_timestamp) - - for i_layer in range(3, 10): - if loc // 2**(i_layer - 3) == 1: - cgraph.add_layer(i_layer, np.array([[0, 0, 0], [1, 1, 1]]), time_stamp=fake_timestamp, n_threads=1) - elif loc // 2**(i_layer - 3) == 0: - cgraph.add_layer(i_layer, np.array([[0, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - else: - cgraph.add_layer(i_layer, np.array([[0, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.add_layer(i_layer, np.array([[loc // 2**(i_layer - 3), loc // 2**(i_layer - 3), loc // 2**(i_layer - 3)]]), time_stamp=fake_timestamp, n_threads=1) - - assert cgraph.get_root(to_label(cgraph, 1, loc, loc, loc, 0)) == \ - cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) == \ - cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) - - # Split - new_root_ids = cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, loc, loc, loc, 0), to_label(cgraph, 1, 0, 0, 0, 0), mincut=False).new_root_ids - - # Check New State - assert len(new_root_ids) == 1 - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) == new_root_ids[0] - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) == new_root_ids[0] - assert cgraph.get_root(to_label(cgraph, 1, loc, loc, loc, 0)) == new_root_ids[0] - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 0, 0, 0, 1) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, loc, loc, loc, 0)) - assert len(partners) == 1 and partners[0] == to_label(cgraph, 1, 0, 0, 0, 1) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 1)) - assert len(partners) == 2 - assert to_label(cgraph, 1, 0, 0, 0, 0) in partners - assert to_label(cgraph, 1, loc, loc, loc, 0) in partners - leaves = np.unique(cgraph.get_subgraph_nodes(new_root_ids[0])) - assert len(leaves) == 3 - assert to_label(cgraph, 1, 0, 0, 0, 0) in leaves - assert to_label(cgraph, 1, 0, 0, 0, 1) in leaves - assert to_label(cgraph, 1, loc, loc, loc, 0) in leaves - - # Check Old State still accessible - old_root_id = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), time_stamp=fake_timestamp) - assert new_root_ids[0] != old_root_id - - assert len(cgraph.get_latest_roots()) == 1 - assert len(cgraph.get_latest_roots(fake_timestamp)) == 1 - - @pytest.mark.timeout(30) - def test_split_same_node(self, gen_graph): - """ - Try to remove (non-existing) edge between RG supervoxel 1 and itself - ┌─────┐ - │ A¹ │ - │ 1 │ => Reject - │ │ - └─────┘ - """ - - cgraph = gen_graph(n_layers=2) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - 
timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) - - res_old = cgraph.table.read_rows() - res_old.consume_all() - - # Split - with pytest.raises(cg_exceptions.PreconditionError): - cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), mincut=False) - - res_new = cgraph.table.read_rows() - res_new.consume_all() - - assert res_new.rows == res_old.rows - - @pytest.mark.timeout(30) - def test_split_pair_abstract_nodes(self, gen_graph): - """ - Try to remove (non-existing) edge between RG supervoxel 1 and abstract node "2" - ┌─────┐ - │ B² │ - │ "2" │ - │ │ - └─────┘ - ┌─────┐ => Reject - │ A¹ │ - │ 1 │ - │ │ - └─────┘ - """ - - cgraph = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) - - # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) - - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - - res_old = cgraph.table.read_rows() - res_old.consume_all() - - # Split - with pytest.raises(cg_exceptions.PreconditionError): - cgraph.remove_edges("Jane Doe", to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 2, 1, 0, 0, 1), mincut=False) - - res_new = cgraph.table.read_rows() - res_new.consume_all() - - assert res_new.rows == res_old.rows - - @pytest.mark.timeout(30) - def test_diagonal_connections(self, gen_graph): - """ - Create graph with edge between RG supervoxels 1 and 2 (same chunk) - and edge between RG supervoxels 1 and 3 (neighboring chunks) - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 2━1━┿━━3 │ - │ / │ │ - ┌─────┬─────┐ - │ | │ │ - │ 4━━┿━━5 │ - │ C¹ │ D¹ │ - └─────┴─────┘ - """ - - cgraph = gen_graph(n_layers=3) - - # Chunk A - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5), - (to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), inf), - (to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 1, 0, 0), inf)]) - - # Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), inf)]) - - # Chunk C - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 1, 0, 0)], - edges=[(to_label(cgraph, 1, 0, 1, 0, 0), to_label(cgraph, 1, 1, 1, 0, 0), inf), - (to_label(cgraph, 1, 0, 1, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), inf)]) - - # Chunk D - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 1, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 1, 0, 0), to_label(cgraph, 1, 0, 1, 0, 0), inf)]) - - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]), n_threads=1) - - rr = cgraph.range_read_chunk(chunk_id=cgraph.get_chunk_id(layer=3, x=0, y=0, z=0)) - root_ids_t0 = list(rr.keys()) - - assert len(root_ids_t0) == 1 - - child_ids = [] - for root_id in root_ids_t0: - cgraph.logger.debug(("root_id", root_id)) - child_ids.extend(cgraph.get_subgraph_nodes(root_id)) - - - - new_roots = cgraph.remove_edges("Jane Doe", - to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 0, 0, 0, 1), - mincut=False).new_root_ids - - assert len(new_roots) == 2 - assert cgraph.get_root(to_label(cgraph, 1, 1, 1, 0, 0)) == \ - cgraph.get_root(to_label(cgraph, 1, 
0, 1, 0, 0)) - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) == \ - cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) - - # @pytest.mark.timeout(30) - # def test_shatter(self, gen_graph): - # """ - # Create graph with edge between RG supervoxels 1 and 2 (same chunk) - # and edge between RG supervoxels 1 and 3 (neighboring chunks) - # ┌─────┬─────┐ - # │ A¹ │ B¹ │ - # │ 2━1━┿━━3 │ - # │ / │ │ - # ┌─────┬─────┐ - # │ | │ │ - # │ 4━━┿━━5 │ - # │ C¹ │ D¹ │ - # └─────┴─────┘ - # """ - # - # cgraph = gen_graph(n_layers=3) - # - # # Chunk A - # create_chunk(cgraph, - # vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - # edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), 0.5), - # (to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), inf), - # (to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 1, 0, 0), inf)]) - # - # # Chunk B - # create_chunk(cgraph, - # vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - # edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), inf)]) - # - # # Chunk C - # create_chunk(cgraph, - # vertices=[to_label(cgraph, 1, 0, 1, 0, 0)], - # edges=[(to_label(cgraph, 1, 0, 1, 0, 0), to_label(cgraph, 1, 1, 1, 0, 0), .1), - # (to_label(cgraph, 1, 0, 1, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), inf)]) - # - # # Chunk D - # create_chunk(cgraph, - # vertices=[to_label(cgraph, 1, 1, 1, 0, 0)], - # edges=[(to_label(cgraph, 1, 1, 1, 0, 0), to_label(cgraph, 1, 0, 1, 0, 0), .1)]) - # - # cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]), n_threads=1) - # - # new_root_ids = cgraph.shatter_nodes("Jane Doe", atomic_node_ids=[to_label(cgraph, 1, 0, 0, 0, 0)]) - # - # cgraph.logger.debug(new_root_ids) - # - # assert len(new_root_ids) == 3 - - - -class TestGraphMergeSplit: - @pytest.mark.timeout(30) - def test_multiple_cuts_and_splits(self, gen_graph_simplequerytest): - """ - ┌─────┬─────┬─────┐ L X Y Z S L X Y Z S L X Y Z S L X Y Z S - │ A¹ │ B¹ │ C¹ │ 1: 1 0 0 0 0 ─── 2 0 0 0 1 ───────────────── 4 0 0 0 1 - │ 1 │ 3━2━┿━━4 │ 2: 1 1 0 0 0 ─┬─ 2 1 0 0 1 ─── 3 0 0 0 1 ─┬─ 4 0 0 0 2 - │ │ │ │ 3: 1 1 0 0 1 ─┘ │ - └─────┴─────┴─────┘ 4: 1 2 0 0 0 ─── 2 2 0 0 1 ─── 3 1 0 0 1 ─┘ - """ - cgraph = gen_graph_simplequerytest - - rr = cgraph.range_read_chunk(chunk_id=cgraph.get_chunk_id(layer=4, x=0, y=0, z=0)) - root_ids_t0 = list(rr.keys()) - child_ids = [] - for root_id in root_ids_t0: - cgraph.logger.debug(f"root_id {root_id}") - child_ids.extend(cgraph.get_subgraph_nodes(root_id)) - - for i in range(10): - cgraph.logger.debug(f"\n\nITERATION {i}/10") - - print(f"\n\nITERATION {i}/10") - print("\n\nMERGE 1 & 3\n\n") - cgraph.logger.debug("\n\nMERGE 1 & 3\n\n") - new_roots = cgraph.add_edges("Jane Doe", - [to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 1)], - affinities=.9).new_root_ids - assert len(new_roots) == 1 - - subgraph_dict = cgraph.get_subgraph_nodes(new_roots[0], return_layers=[3, 2, 1]) - - assert len(cgraph.get_subgraph_nodes(new_roots[0])) == 4 - - root_ids = [] - for child_id in child_ids: - root_ids.append(cgraph.get_root(child_id)) - cgraph.logger.debug((child_id, cgraph.get_chunk_coordinates(child_id), root_ids[-1])) - - parent_id = cgraph.get_parent(child_id) - cgraph.logger.debug((parent_id, cgraph.read_cross_chunk_edges(parent_id))) - - u_root_ids = np.unique(root_ids) - assert len(u_root_ids) == 1 - - # ------------------------------------------------------------------ - - - cgraph.logger.debug("\n\nSPLIT 2 & 3\n\n") - - new_roots = 
cgraph.remove_edges("John Doe", to_label(cgraph, 1, 1, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 1), mincut=False).new_root_ids - - assert len(np.unique(new_roots)) == 2 - - for root in new_roots: - cgraph.logger.debug(("SUBGRAPH", cgraph.get_subgraph_nodes(root))) - - cgraph.logger.debug("test children") - root_ids = [] - for child_id in child_ids: - root_ids.append(cgraph.get_root(child_id)) - cgraph.logger.debug((child_id, cgraph.get_chunk_coordinates(child_id), cgraph.get_segment_id(child_id), root_ids[-1])) - cgraph.logger.debug((cgraph.get_atomic_node_info(child_id))) - - cgraph.logger.debug("test root") u_root_ids = np.unique(root_ids) - these_child_ids = [] - for root_id in u_root_ids: - these_child_ids.extend(cgraph.get_subgraph_nodes(root_id, verbose=False)) - cgraph.logger.debug((root_id, cgraph.get_subgraph_nodes(root_id, verbose=False))) - - assert len(these_child_ids) == 4 assert len(u_root_ids) == 2 - # ------------------------------------------------------------------ + # for root_id in root_ids: + # cross_edge_dict_layers = graph_tests.root_cross_edge_test( + # root_id, cg=cg + # ) # dict: layer -> cross_edge_dict + # n_cross_edges_layer = collections.defaultdict(list) - cgraph.logger.debug("\n\nSPLIT 1 & 3\n\n") - new_roots = cgraph.remove_edges("Jane Doe", - to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 1), - mincut=False).new_root_ids - assert len(new_roots) == 2 + # for child_layer in cross_edge_dict_layers.keys(): + # for layer in cross_edge_dict_layers[child_layer].keys(): + # n_cross_edges_layer[layer].append( + # len(cross_edge_dict_layers[child_layer][layer]) + # ) - root_ids = [] - for child_id in child_ids: - root_ids.append(cgraph.get_root(child_id)) - cgraph.logger.debug((child_id, cgraph.get_chunk_coordinates(child_id), root_ids[-1])) - - parent_id = cgraph.get_parent(child_id) - cgraph.logger.debug((parent_id, cgraph.read_cross_chunk_edges(parent_id))) - - u_root_ids = np.unique(root_ids) - assert len(u_root_ids) == 3 - - - # ------------------------------------------------------------------ - - cgraph.logger.debug("\n\nMERGE 2 & 3\n\n") - - print(f"\n\nITERATION {i}/10") - print("\n\nMERGE 2 & 3\n\n") - - new_roots = cgraph.add_edges("Jane Doe", - [to_label(cgraph, 1, 1, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 1)], - affinities=.9).new_root_ids - assert len(new_roots) == 1 - - root_ids = [] - for child_id in child_ids: - root_ids.append(cgraph.get_root(child_id)) - cgraph.logger.debug((child_id, cgraph.get_chunk_coordinates(child_id), root_ids[-1])) - - parent_id = cgraph.get_parent(child_id) - cgraph.logger.debug((parent_id, cgraph.read_cross_chunk_edges(parent_id))) - - u_root_ids = np.unique(root_ids) - assert len(u_root_ids) == 2 - - - for root_id in root_ids: - cross_edge_dict_layers = graph_tests.root_cross_edge_test(root_id, cg=cgraph) # dict: layer -> cross_edge_dict - n_cross_edges_layer = collections.defaultdict(list) - - for child_layer in cross_edge_dict_layers.keys(): - for layer in cross_edge_dict_layers[child_layer].keys(): - n_cross_edges_layer[layer].append(len(cross_edge_dict_layers[child_layer][layer])) - - for layer in n_cross_edges_layer.keys(): - assert len(np.unique(n_cross_edges_layer[layer])) == 1 + # for layer in n_cross_edges_layer.keys(): + # assert len(np.unique(n_cross_edges_layer[layer])) == 1 class TestGraphMinCut: @@ -2301,40 +1572,65 @@ def test_cut_regular_link(self, gen_graph): └─────┴─────┘ """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = 
datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), 0.5)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), 0.5)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) # Mincut - new_root_ids = cgraph.remove_edges( - "Jane Doe", to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), - [0, 0, 0], [2*cgraph.chunk_size[0], 2*cgraph.chunk_size[1], cgraph.chunk_size[2]], - mincut=True).new_root_ids + new_root_ids = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + source_coords=[0, 0, 0], + sink_coords=[ + 2 * cg.meta.graph_config.CHUNK_SIZE[0], + 2 * cg.meta.graph_config.CHUNK_SIZE[1], + cg.meta.graph_config.CHUNK_SIZE[2], + ], + mincut=True, + disallow_isolating_cut=True, + ).new_root_ids # Check New State assert len(new_root_ids) == 2 - assert cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) != cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 0, 0, 0, 0)) - assert len(partners) == 0 - partners, affinities, areas = cgraph.get_atomic_partners(to_label(cgraph, 1, 1, 0, 0, 0)) - assert len(partners) == 0 - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)))) - assert len(leaves) == 1 and to_label(cgraph, 1, 0, 0, 0, 0) in leaves - leaves = np.unique(cgraph.get_subgraph_nodes(cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)))) - assert len(leaves) == 1 and to_label(cgraph, 1, 1, 0, 0, 0) in leaves + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( + to_label(cg, 1, 1, 0, 0, 0) + ) + leaves = np.unique( + cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True + ) + ) + assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves + leaves = np.unique( + cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 1, 0, 0, 0))], leaves_only=True + ) + ) + assert len(leaves) == 1 and to_label(cg, 1, 1, 0, 0, 0) in leaves @pytest.mark.timeout(30) def test_cut_no_link(self, gen_graph): @@ -2347,34 +1643,52 @@ def test_cut_no_link(self, gen_graph): └─────┴─────┘ """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 
0, 0]]), time_stamp=fake_timestamp, n_threads=1) + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) - res_old = cgraph.table.read_rows() + res_old = cg.client._table.read_rows() res_old.consume_all() # Mincut - with pytest.raises(cg_exceptions.PreconditionError): - cgraph.remove_edges( - "Jane Doe", to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), - [0, 0, 0], [2*cgraph.chunk_size[0], 2*cgraph.chunk_size[1], cgraph.chunk_size[2]], - mincut=True) - - res_new = cgraph.table.read_rows() + with pytest.raises(exceptions.PreconditionError): + cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + source_coords=[0, 0, 0], + sink_coords=[ + 2 * cg.meta.graph_config.CHUNK_SIZE[0], + 2 * cg.meta.graph_config.CHUNK_SIZE[1], + cg.meta.graph_config.CHUNK_SIZE[2], + ], + mincut=True, + ) + + res_new = cg.client._table.read_rows() res_new.consume_all() assert res_new.rows == res_old.rows @@ -2390,35 +1704,58 @@ def test_cut_old_link(self, gen_graph): └─────┴─────┘ """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), 0.5)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), 0.5)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), time_stamp=fake_timestamp, n_threads=1) - cgraph.remove_edges("John Doe", to_label(cgraph, 1, 1, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 0), mincut=False) + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + cg.remove_edges( + "John Doe", + source_ids=to_label(cg, 1, 1, 0, 0, 0), + sink_ids=to_label(cg, 1, 0, 0, 0, 0), + mincut=False, + ) - res_old = cgraph.table.read_rows() + res_old = cg.client._table.read_rows() res_old.consume_all() # Mincut - with pytest.raises(cg_exceptions.PreconditionError): - cgraph.remove_edges( - "Jane Doe", to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 1, 0, 0, 0), - [0, 0, 0], [2*cgraph.chunk_size[0], 2*cgraph.chunk_size[1], cgraph.chunk_size[2]], - mincut=True) - - res_new = cgraph.table.read_rows() + with pytest.raises(exceptions.PreconditionError): + cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + source_coords=[0, 0, 0], + sink_coords=[ + 2 * cg.meta.graph_config.CHUNK_SIZE[0], + 2 * cg.meta.graph_config.CHUNK_SIZE[1], + cg.meta.graph_config.CHUNK_SIZE[2], + ], + mincut=True, + ) + + res_new = cg.client._table.read_rows() res_new.consume_all() assert res_new.rows == res_old.rows @@ -2435,42 +1772,57 @@ def test_cut_indivisible_link(self, gen_graph): └─────┴─────┘ """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 
0)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 0), inf)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf)], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), - to_label(cgraph, 1, 0, 0, 0, 0), inf)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], + timestamp=fake_timestamp, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), - time_stamp=fake_timestamp, n_threads=1) + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) - original_parents_1 = cgraph.get_root( - to_label(cgraph, 1, 0, 0, 0, 0), get_all_parents=True) - original_parents_2 = cgraph.get_root( - to_label(cgraph, 1, 1, 0, 0, 0), get_all_parents=True) + original_parents_1 = cg.get_root( + to_label(cg, 1, 0, 0, 0, 0), get_all_parents=True + ) + original_parents_2 = cg.get_root( + to_label(cg, 1, 1, 0, 0, 0), get_all_parents=True + ) # Mincut - with pytest.raises(cg_exceptions.PostconditionError): - cgraph.remove_edges( - "Jane Doe", to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 0), - [0, 0, 0], [2 * cgraph.chunk_size[0], 2 * cgraph.chunk_size[1], - cgraph.chunk_size[2]], - mincut=True) - - new_parents_1 = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0), get_all_parents=True) - new_parents_2 = cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0), get_all_parents=True) + with pytest.raises(exceptions.PostconditionError): + cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + source_coords=[0, 0, 0], + sink_coords=[ + 2 * cg.meta.graph_config.CHUNK_SIZE[0], + 2 * cg.meta.graph_config.CHUNK_SIZE[1], + cg.meta.graph_config.CHUNK_SIZE[2], + ], + mincut=True, + ) + + new_parents_1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0), get_all_parents=True) + new_parents_2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0), get_all_parents=True) assert np.all(np.array(original_parents_1) == np.array(new_parents_1)) assert np.all(np.array(original_parents_2) == np.array(new_parents_2)) @@ -2483,21 +1835,35 @@ def test_mincut_disrespects_sources_or_sinks(self, gen_graph): two sinks, this can happen when an edge along the only path between two sources or two sinks is cut. 
""" - cgraph = gen_graph(n_layers=2) + cg = gen_graph(n_layers=2) fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 2), to_label(cgraph, 1, 0, 0, 0, 3)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 2), 2), (to_label(cgraph, 1, 0, 0, 0, 1), to_label(cgraph, 1, 0, 0, 0, 2), 3), (to_label(cgraph, 1, 0, 0, 0, 2), to_label(cgraph, 1, 0, 0, 0, 3), 10)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[ + to_label(cg, 1, 0, 0, 0, 0), + to_label(cg, 1, 0, 0, 0, 1), + to_label(cg, 1, 0, 0, 0, 2), + to_label(cg, 1, 0, 0, 0, 3), + ], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 2), 2), + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2), 3), + (to_label(cg, 1, 0, 0, 0, 2), to_label(cg, 1, 0, 0, 0, 3), 10), + ], + timestamp=fake_timestamp, + ) # Mincut - with pytest.raises(cg_exceptions.PreconditionError): - cgraph.remove_edges( - "Jane Doe", [to_label(cgraph, 1, 0, 0, 0, 0), to_label(cgraph, 1, 0, 0, 0, 1)], - [to_label(cgraph, 1, 0, 0, 0, 3)], - [[0, 0, 0], [10,0,0]], [[5,5,0]], - mincut=True) + with pytest.raises(exceptions.PreconditionError): + cg.remove_edges( + "Jane Doe", + source_ids=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + sink_ids=[to_label(cg, 1, 0, 0, 0, 3)], + source_coords=[[0, 0, 0], [10, 0, 0]], + sink_coords=[[5, 5, 0]], + mincut=True, + ) class TestGraphMultiCut: @@ -2505,10 +1871,25 @@ class TestGraphMultiCut: def test_cut_multi_tree(self, gen_graph): pass + @pytest.mark.timeout(30) + def test_path_augmented_multicut(self, sv_data): + sv_edges, sv_sources, sv_sinks, sv_affinity, sv_area = sv_data + edges = Edges( + sv_edges[:, 0], sv_edges[:, 1], affinities=sv_affinity, areas=sv_area + ) + + cut_edges_aug = run_multicut(edges, sv_sources, sv_sinks, path_augment=True) + assert cut_edges_aug.shape[0] == 350 + + with pytest.raises(exceptions.PreconditionError): + run_multicut(edges, sv_sources, sv_sinks, path_augment=False) + pass + class TestGraphHistory: - """ These test inadvertantly also test merge and split operations """ - @pytest.mark.timeout(30) + """These test inadvertantly also test merge and split operations""" + + @pytest.mark.timeout(120) def test_cut_merge_history(self, gen_graph): """ Regular link between 1 and 2 @@ -2520,78 +1901,130 @@ def test_cut_merge_history(self, gen_graph): (1) Split 1 and 2 (2) Merge 1 and 2 """ + from ..graph.lineage import lineage_graph - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 0), 0.5)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 0)], - edges=[(to_label(cgraph, 1, 1, 0, 0, 0), - to_label(cgraph, 1, 0, 0, 0, 0), 0.5)], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), - time_stamp=fake_timestamp, n_threads=1) + add_layer( + cg, + 
3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) - first_root = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 0)) - assert first_root == cgraph.get_root(to_label(cgraph, 1, 1, 0, 0, 0)) + first_root = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + assert first_root == cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) timestamp_before_split = datetime.utcnow() - split_roots = cgraph.remove_edges("Jane Doe", - to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 0), - mincut=False).new_root_ids - + split_roots = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ).new_root_ids assert len(split_roots) == 2 - timestamp_after_split = datetime.utcnow() - merge_roots = cgraph.add_edges("Jane Doe", - [to_label(cgraph, 1, 0, 0, 0, 0), - to_label(cgraph, 1, 1, 0, 0, 0)], - affinities=.4).new_root_ids + g = lineage_graph(cg, split_roots[0]) + assert g.size() == 1 + g = lineage_graph(cg, split_roots) + assert g.size() == 2 + + timestamp_after_split = datetime.utcnow() + merge_roots = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)], + affinities=0.4, + ).new_root_ids assert len(merge_roots) == 1 merge_root = merge_roots[0] timestamp_after_merge = datetime.utcnow() - assert len(cgraph.get_root_id_history(first_root, - time_stamp_past=datetime.min, - time_stamp_future=datetime.max)) == 4 - assert len(cgraph.get_root_id_history(split_roots[0], - time_stamp_past=datetime.min, - time_stamp_future=datetime.max)) == 3 - assert len(cgraph.get_root_id_history(split_roots[1], - time_stamp_past=datetime.min, - time_stamp_future=datetime.max)) == 3 - assert len(cgraph.get_root_id_history(merge_root, - time_stamp_past=datetime.min, - time_stamp_future=datetime.max)) == 4 - - new_roots, old_roots = cgraph.get_delta_roots(timestamp_before_split, - timestamp_after_split) - assert(len(old_roots)==1) - assert(old_roots[0]==first_root) - assert(len(new_roots)==2) - assert(np.all(np.isin(new_roots, split_roots))) - - new_roots2, old_roots2 = cgraph.get_delta_roots(timestamp_after_split, - timestamp_after_merge) - assert(len(new_roots2)==1) - assert(new_roots2[0]==merge_root) - assert(len(old_roots2)==2) - assert(np.all(np.isin(old_roots2, split_roots))) - - new_roots3, old_roots3 = cgraph.get_delta_roots(timestamp_before_split, - timestamp_after_merge) - assert(len(new_roots3)==1) - assert(new_roots3[0]==merge_root) - assert(len(old_roots3)==1) - assert(old_roots3[0]==first_root) + g = lineage_graph(cg, merge_roots) + assert g.size() == 4 + assert ( + len( + get_root_id_history( + cg, + first_root, + time_stamp_past=datetime.min, + time_stamp_future=datetime.max, + ) + ) + == 4 + ) + assert ( + len( + get_root_id_history( + cg, + split_roots[0], + time_stamp_past=datetime.min, + time_stamp_future=datetime.max, + ) + ) + == 3 + ) + assert ( + len( + get_root_id_history( + cg, + split_roots[1], + time_stamp_past=datetime.min, + time_stamp_future=datetime.max, + ) + ) + == 3 + ) + assert ( + len( + get_root_id_history( + cg, + merge_root, + time_stamp_past=datetime.min, + time_stamp_future=datetime.max, + ) + ) + == 4 + ) + + new_roots, old_roots = get_delta_roots( + cg, timestamp_before_split, timestamp_after_split + ) + assert len(old_roots) == 1 + assert old_roots[0] == first_root + assert len(new_roots) == 2 + assert np.all(np.isin(new_roots, split_roots)) + + new_roots2, old_roots2 = get_delta_roots( + cg, timestamp_after_split, timestamp_after_merge + ) + assert len(new_roots2) == 1 + 
assert new_roots2[0] == merge_root + assert len(old_roots2) == 2 + assert np.all(np.isin(old_roots2, split_roots)) + + new_roots3, old_roots3 = get_delta_roots( + cg, timestamp_before_split, timestamp_after_merge + ) + assert len(new_roots3) == 1 + assert new_roots3[0] == merge_root + assert len(old_roots3) == 1 + assert old_roots3[0] == first_root class TestGraphLocks: @@ -2611,42 +2044,60 @@ def test_lock_unlock(self, gen_graph): (4) Try lock (opid = 2) """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 1), - to_label(cgraph, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), - time_stamp=fake_timestamp, n_threads=1) + operation_id_1 = cg.id_client.create_operation_id() + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) - operation_id_1 = cgraph.get_unique_operation_id() - root_id = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) - assert cgraph.lock_root_loop(root_ids=[root_id], - operation_id=operation_id_1)[0] + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] - operation_id_2 = cgraph.get_unique_operation_id() - assert not cgraph.lock_root_loop(root_ids=[root_id], - operation_id=operation_id_2)[0] + operation_id_2 = cg.id_client.create_operation_id() + assert not cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] - assert cgraph.unlock_root(root_id=root_id, - operation_id=operation_id_1) + assert cg.client.unlock_root(root_id=root_id, operation_id=operation_id_1) - assert cgraph.lock_root_loop(root_ids=[root_id], - operation_id=operation_id_2)[0] + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] @pytest.mark.timeout(30) - def test_lock_expiration(self, gen_graph, lock_expired_timedelta_override): + def test_lock_expiration(self, gen_graph): """ No connection between 1, 2 and 3 ┌─────┬─────┐ @@ -2659,38 +2110,58 @@ def test_lock_expiration(self, gen_graph, lock_expired_timedelta_override): (2) Try lock (opid = 2) (3) Try lock (opid = 2) with retries """ - - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 1), - to_label(cgraph, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp) - - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), - time_stamp=fake_timestamp, n_threads=1) - 
- operation_id_1 = cgraph.get_unique_operation_id() - root_id = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) - assert cgraph.lock_root_loop(root_ids=[root_id], - operation_id=operation_id_1)[0] + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) - operation_id_2 = cgraph.get_unique_operation_id() - assert not cgraph.lock_root_loop(root_ids=[root_id], - operation_id=operation_id_2)[0] + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) - assert cgraph.lock_root_loop(root_ids=[root_id], - operation_id=operation_id_2, - max_tries=10, waittime_s=.5)[0] + operation_id_1 = cg.id_client.create_operation_id() + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] + + operation_id_2 = cg.id_client.create_operation_id() + assert not cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + sleep(cg.meta.graph_config.ROOT_LOCK_EXPIRY.total_seconds()) + + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + max_tries=10, + waittime_s=0.5, + )[0] @pytest.mark.timeout(30) def test_lock_renew(self, gen_graph): @@ -2707,32 +2178,43 @@ def test_lock_renew(self, gen_graph): (3) Try lock (opid = 2) with retries """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 1), - to_label(cgraph, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), - time_stamp=fake_timestamp, n_threads=1) + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) - operation_id_1 = cgraph.get_unique_operation_id() - root_id = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) - assert cgraph.lock_root_loop(root_ids=[root_id], - operation_id=operation_id_1)[0] + operation_id_1 = cg.id_client.create_operation_id() + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] - assert cgraph.check_and_renew_root_locks(root_ids=[root_id], - operation_id=operation_id_1) + assert cg.client.renew_locks(root_ids=[root_id], operation_id=operation_id_1) @pytest.mark.timeout(30) def test_lock_merge_lock_old_id(self, gen_graph): @@ -2748,118 +2230,1315 @@ def test_lock_merge_lock_old_id(self, gen_graph): (2) Try lock opid 2 --> should be successful and return new root id """ - cgraph = gen_graph(n_layers=3) + cg = gen_graph(n_layers=3) # Preparation: Build Chunk A fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 0, 0, 0, 1), - to_label(cgraph, 1, 0, 0, 0, 2)], - 
edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) # Preparation: Build Chunk B - create_chunk(cgraph, - vertices=[to_label(cgraph, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) - cgraph.add_layer(3, np.array([[0, 0, 0], [1, 0, 0]]), - time_stamp=fake_timestamp, n_threads=1) + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) - root_id = cgraph.get_root(to_label(cgraph, 1, 0, 0, 0, 1)) + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) - new_root_ids = cgraph.add_edges("Chuck Norris", [to_label(cgraph, 1, 0, 0, 0, 1), - to_label(cgraph, 1, 0, 0, 0, 2)], affinities=1.).new_root_ids + new_root_ids = cg.add_edges( + "Chuck Norris", + [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + affinities=1.0, + ).new_root_ids assert new_root_ids is not None - operation_id_2 = cgraph.get_unique_operation_id() - success, new_root_id = cgraph.lock_root_loop(root_ids=[root_id], - operation_id=operation_id_2, - max_tries=10, waittime_s=.5) + operation_id_2 = cg.id_client.create_operation_id() + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + success, new_root_id = cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + max_tries=10, + waittime_s=0.5, + ) - cgraph.logger.debug(new_root_id) assert success assert new_root_ids[0] == new_root_id + @pytest.mark.timeout(30) + def test_indefinite_lock(self, gen_graph): + """ + No connection between 1, 2 and 3 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 3 │ + │ 2 │ │ + └─────┴─────┘ -class MockChunkedGraph: - """ - Dummy class to mock partial functionality of the ChunkedGraph for use in unit tests. - Feel free to add more functions as need be. Can pass in alternative member functions into constructor. 
- """ + (1) Try indefinite lock (opid = 1), get indefinite lock + (2) Try normal lock (opid = 2), doesn't get the normal lock + (3) Try unlock indefinite lock (opid = 1), should unlock indefinite lock + (4) Try lock (opid = 2), should get the normal lock + """ - def __init__( - self, get_chunk_coordinates=None, get_chunk_layer=None, get_chunk_id=None - ): - if get_chunk_coordinates is not None: - self.get_chunk_coordinates = get_chunk_coordinates - if get_chunk_layer is not None: - self.get_chunk_layer = get_chunk_layer - if get_chunk_id is not None: - self.get_chunk_id = get_chunk_id + cg = gen_graph(n_layers=3) - def get_chunk_coordinates(self, chunk_id): - return np.array([0, 0, 0]) + # Preparation: Build Chunk A + fake_timestamp = datetime.utcnow() - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) - def get_chunk_layer(self, chunk_id): - return 2 + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) - def get_chunk_id(self, *args): - return 0 + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + operation_id_1 = cg.id_client.create_operation_id() + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + assert cg.client.lock_roots_indefinitely( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] + + operation_id_2 = cg.id_client.create_operation_id() + assert not cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + assert cg.client.unlock_indefinitely_locked_root( + root_id=root_id, operation_id=operation_id_1 + ) + + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] -class TestMeshes: @pytest.mark.timeout(30) - @mock.patch( - "pychunkedgraph.meshing.meshgen.get_meshing_necessities_from_graph", - return_value=(0, 0, np.array([1, 12, 5])), - ) - @mock.patch( - "pychunkedgraph.meshing.meshgen.get_draco_encoding_settings_for_chunk", - return_value={ - "quantization_bits": 3, - "quantization_range": 21, - "quantization_origin": np.array([-1, 11, 3]), - }, - ) - def test_merge_draco_meshes_across_boundaries(self, *args): + def test_indefinite_lock_with_normal_lock_expiration(self, gen_graph): """ - Test merging a list of quantized draco meshes across a chunk boundary. - In meshgen the quantization parameters are determined using characteristics - of the chunk the mesh comes from, but here they are mocked out. 
+ No connection between 1, 2 and 3 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 3 │ + │ 2 │ │ + └─────┴─────┘ + + (1) Try normal lock (opid = 1), get normal lock + (2) Try indefinite lock (opid = 1), get indefinite lock + (3) Wait until normal lock expires + (4) Try normal lock (opid = 2), doesn't get the normal lock + (5) Try unlock indefinite lock (opid = 1), should unlock indefinite lock + (6) Try lock (opid = 2), should get the normal lock """ - mock_cg = MockChunkedGraph() - fragments = [ - { - "mesh": { - "num_vertices": 4, - "vertices": np.array( - [[7, 15, 4], [2, 10, 9], [9, 11, 9], [3, 7, 6]] - ), - "faces": np.array([0, 1, 2, 1, 2, 3]), - } - }, - { - "mesh": { - "num_vertices": 5, - "vertices": np.array( - [[10, 10, 10], [2, 10, 9], [7, 15, 4], [9, 11, 9], [3, 7, 6]] - ), - "faces": np.array([0, 1, 2, 0, 1, 3, 0, 2, 3, 1, 2, 3, 0, 1, 4]), - } - }, - ] - merged_vertices = meshgen.merge_draco_meshes_across_boundaries( - mock_cg, fragments, 0, 0, 0 + + # 1. TODO renew lock test when getting indefinite lock + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.utcnow() - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, ) - expected_vertices = np.array( - [7, 15, 4, 10, 10, 10, 7, 15, 4, 2, 10, 9, 3, 7, 6, 9, 11, 9] + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, ) - expected_faces = np.array( - [0, 3, 5, 3, 5, 4, 1, 3, 2, 1, 3, 5, 1, 2, 5, 3, 2, 5, 1, 3, 4] + + add_layer( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + operation_id_1 = cg.id_client.create_operation_id() + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] + + assert cg.client.lock_roots_indefinitely( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] + + sleep(cg.meta.graph_config.ROOT_LOCK_EXPIRY.total_seconds()) + + operation_id_2 = cg.id_client.create_operation_id() + assert not cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + assert cg.client.unlock_indefinitely_locked_root( + root_id=root_id, operation_id=operation_id_1 ) - assert merged_vertices["num_vertices"] == 6 - assert np.array_equal(merged_vertices["vertices"], expected_vertices) - assert np.array_equal(merged_vertices["faces"], expected_faces) + + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + # TODO fixme: this scenario can't be tested like this + # @pytest.mark.timeout(30) + # def test_normal_lock_expiration(self, gen_graph): + # """ + # No connection between 1, 2 and 3 + # ┌─────┬─────┐ + # │ A¹ │ B¹ │ + # │ 1 │ 3 │ + # │ 2 │ │ + # └─────┴─────┘ + + # (1) Try normal lock (opid = 1), get normal lock + # (2) Wait until normal lock expires + # (3) Try indefinite lock (opid = 1), doesn't get the indefinite lock + # """ + + # cg = gen_graph(n_layers=3) + + # # Preparation: Build Chunk A + # fake_timestamp = datetime.utcnow() - timedelta(days=10) + # create_chunk( + # cg, + # vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + # edges=[], + # timestamp=fake_timestamp, + # ) + + # # 
Preparation: Build Chunk B + # create_chunk( + # cg, + # vertices=[to_label(cg, 1, 1, 0, 0, 1)], + # edges=[], + # timestamp=fake_timestamp, + # ) + + # add_layer( + # cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1, + # ) + + # operation_id_1 = cg.id_client.create_operation_id() + # root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + + # future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + + # assert cg.client.lock_roots( + # root_ids=[root_id], + # operation_id=operation_id_1, + # future_root_ids_d=future_root_ids_d, + # )[0] + + # sleep(cg.meta.graph_config.ROOT_LOCK_EXPIRY.total_seconds()+1) + + # assert not cg.client.lock_roots_indefinitely( + # root_ids=[root_id], + # operation_id=operation_id_1, + # future_root_ids_d=future_root_ids_d, + # )[0] + + +# class MockChunkedGraph: +# """ +# Dummy class to mock partial functionality of the ChunkedGraph for use in unit tests. +# Feel free to add more functions as need be. Can pass in alternative member functions into constructor. +# """ + +# def __init__( +# self, get_chunk_coordinates=None, get_chunk_layer=None, get_chunk_id=None +# ): +# if get_chunk_coordinates is not None: +# self.get_chunk_coordinates = get_chunk_coordinates +# if get_chunk_layer is not None: +# self.get_chunk_layer = get_chunk_layer +# if get_chunk_id is not None: +# self.get_chunk_id = get_chunk_id + +# def get_chunk_coordinates(self, chunk_id): # pylint: disable=method-hidden +# return np.array([0, 0, 0]) + +# def get_chunk_layer(self, chunk_id): # pylint: disable=method-hidden +# return 2 + +# def get_chunk_id(self, *args): # pylint: disable=method-hidden +# return 0 + + +# class TestGraphSplit: +# @pytest.mark.timeout(30) +# def test_split_pair_same_chunk(self, gen_graph): +# """ +# Remove edge between existing RG supervoxels 1 and 2 (same chunk) +# Expected: Different (new) parents for RG 1 and 2 on Layer two +# ┌─────┐ ┌─────┐ +# │ A¹ │ │ A¹ │ +# │ 1━2 │ => │ 1 2 │ +# │ │ │ │ +# └─────┘ └─────┘ +# """ + +# cg = gen_graph(n_layers=2) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], +# edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5)], +# timestamp=fake_timestamp, +# ) + +# # Split +# new_root_ids = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 0, 0, 0, 1), +# sink_ids=to_label(cg, 1, 0, 0, 0, 0), +# mincut=False, +# ).new_root_ids + +# # Check New State +# assert len(new_root_ids) == 2 +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( +# to_label(cg, 1, 0, 0, 0, 1) +# ) +# leaves = np.unique( +# cg.get_subgraph([cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True) +# ) +# assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves +# leaves = np.unique( +# cg.get_subgraph([cg.get_root(to_label(cg, 1, 0, 0, 0, 1))], leaves_only=True) +# ) +# assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 1) in leaves + +# # Check Old State still accessible +# assert cg.get_root( +# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp +# ) == cg.get_root(to_label(cg, 1, 0, 0, 0, 1), time_stamp=fake_timestamp) +# leaves = np.unique( +# cg.get_subgraph( +# [cg.get_root(to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp)], +# leaves_only=True, +# ) +# ) +# assert len(leaves) == 2 +# assert to_label(cg, 1, 0, 0, 0, 0) in leaves +# assert to_label(cg, 1, 0, 0, 0, 1) in leaves + +# # assert len(cg.get_latest_roots()) == 2 +# # assert 
len(cg.get_latest_roots(fake_timestamp)) == 1 + +# def test_split_nonexisting_edge(self, gen_graph): +# """ +# Remove edge between existing RG supervoxels 1 and 2 (same chunk) +# Expected: Different (new) parents for RG 1 and 2 on Layer two +# ┌─────┐ ┌─────┐ +# │ A¹ │ │ A¹ │ +# │ 1━2 │ => │ 1━2 │ +# │ | │ │ | │ +# │ 3 │ │ 3 │ +# └─────┘ └─────┘ +# """ + +# cg = gen_graph(n_layers=2) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], +# edges=[ +# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), +# (to_label(cg, 1, 0, 0, 0, 2), to_label(cg, 1, 0, 0, 0, 1), 0.5), +# ], +# timestamp=fake_timestamp, +# ) + +# # Split +# new_root_ids = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 0, 0, 0, 0), +# sink_ids=to_label(cg, 1, 0, 0, 0, 2), +# mincut=False, +# ).new_root_ids + +# assert len(new_root_ids) == 1 + +# @pytest.mark.timeout(30) +# def test_split_pair_neighboring_chunks(self, gen_graph): +# """ +# Remove edge between existing RG supervoxels 1 and 2 (neighboring chunks) +# ┌─────┬─────┐ ┌─────┬─────┐ +# │ A¹ │ B¹ │ │ A¹ │ B¹ │ +# │ 1━━┿━━2 │ => │ 1 │ 2 │ +# │ │ │ │ │ │ +# └─────┴─────┘ └─────┴─────┘ +# """ + +# cg = gen_graph(n_layers=3) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 0, 0, 0)], +# edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 1.0)], +# timestamp=fake_timestamp, +# ) + +# # Preparation: Build Chunk B +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 1, 0, 0, 0)], +# edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 1.0)], +# timestamp=fake_timestamp, +# ) + +# add_layer( +# cg, +# 3, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) + +# # Split +# new_root_ids = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 1, 0, 0, 0), +# sink_ids=to_label(cg, 1, 0, 0, 0, 0), +# mincut=False, +# ).new_root_ids + +# # Check New State +# assert len(new_root_ids) == 2 +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( +# to_label(cg, 1, 1, 0, 0, 0) +# ) +# leaves = np.unique( +# cg.get_subgraph([cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True) +# ) +# assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves +# leaves = np.unique( +# cg.get_subgraph([cg.get_root(to_label(cg, 1, 1, 0, 0, 0))], leaves_only=True) +# ) +# assert len(leaves) == 1 and to_label(cg, 1, 1, 0, 0, 0) in leaves + +# # Check Old State still accessible +# assert cg.get_root( +# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp +# ) == cg.get_root(to_label(cg, 1, 1, 0, 0, 0), time_stamp=fake_timestamp) +# leaves = np.unique( +# cg.get_subgraph( +# [cg.get_root(to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp)], +# leaves_only=True, +# ) +# ) +# assert len(leaves) == 2 +# assert to_label(cg, 1, 0, 0, 0, 0) in leaves +# assert to_label(cg, 1, 1, 0, 0, 0) in leaves + +# assert len(cg.get_latest_roots()) == 2 +# assert len(cg.get_latest_roots(fake_timestamp)) == 1 + +# @pytest.mark.timeout(30) +# def test_split_verify_cross_chunk_edges(self, gen_graph): +# """ +# Remove edge between existing RG supervoxels 1 and 2 (neighboring chunks) +# ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ +# | │ A¹ │ B¹ │ | │ A¹ │ B¹ │ +# | │ 1━━┿━━3 │ => | │ 1━━┿━━3 │ +# | │ | │ │ | │ │ │ +# | │ 2 │ │ | │ 2 │ │ +# └─────┴─────┴─────┘ └─────┴─────┴─────┘ +# """ + +# cg = 
gen_graph(n_layers=4) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], +# edges=[ +# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0), inf), +# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1), 0.5), +# ], +# timestamp=fake_timestamp, +# ) + +# # Preparation: Build Chunk B +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 2, 0, 0, 0)], +# edges=[(to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf)], +# timestamp=fake_timestamp, +# ) + +# add_layer( +# cg, +# 3, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 3, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 4, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) + +# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( +# to_label(cg, 1, 1, 0, 0, 1) +# ) +# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( +# to_label(cg, 1, 2, 0, 0, 0) +# ) + +# # Split +# new_root_ids = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 1, 0, 0, 0), +# sink_ids=to_label(cg, 1, 1, 0, 0, 1), +# mincut=False, +# ).new_root_ids + +# assert len(new_root_ids) == 2 + +# svs2 = cg.get_subgraph([new_root_ids[0]], leaves_only=True) +# svs1 = cg.get_subgraph([new_root_ids[1]], leaves_only=True) +# len_set = {1, 2} +# assert len(svs1) in len_set +# len_set.remove(len(svs1)) +# assert len(svs2) in len_set + +# # Check New State +# assert len(new_root_ids) == 2 +# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) != cg.get_root( +# to_label(cg, 1, 1, 0, 0, 1) +# ) +# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( +# to_label(cg, 1, 2, 0, 0, 0) +# ) + +# cc_dict = cg.get_atomic_cross_edges( +# cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) +# ) +# assert len(cc_dict[3]) == 1 +# assert cc_dict[3][0][0] == to_label(cg, 1, 1, 0, 0, 0) +# assert cc_dict[3][0][1] == to_label(cg, 1, 2, 0, 0, 0) + +# assert len(cg.get_latest_roots()) == 2 +# assert len(cg.get_latest_roots(fake_timestamp)) == 1 + +# @pytest.mark.timeout(30) +# def test_split_verify_loop(self, gen_graph): +# """ +# Remove edge between existing RG supervoxels 1 and 2 (neighboring chunks) +# ┌─────┬────────┬─────┐ ┌─────┬────────┬─────┐ +# | │ A¹ │ B¹ │ | │ A¹ │ B¹ │ +# | │ 4━━1━━┿━━5 │ => | │ 4 1━━┿━━5 │ +# | │ / │ | │ | │ │ | │ +# | │ 3 2━━┿━━6 │ | │ 3 2━━┿━━6 │ +# └─────┴────────┴─────┘ └─────┴────────┴─────┘ +# """ + +# cg = gen_graph(n_layers=4) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[ +# to_label(cg, 1, 1, 0, 0, 0), +# to_label(cg, 1, 1, 0, 0, 1), +# to_label(cg, 1, 1, 0, 0, 2), +# to_label(cg, 1, 1, 0, 0, 3), +# ], +# edges=[ +# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0), inf), +# (to_label(cg, 1, 1, 0, 0, 1), to_label(cg, 1, 2, 0, 0, 1), inf), +# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 2), 0.5), +# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 3), 0.5), +# ], +# timestamp=fake_timestamp, +# ) + +# # Preparation: Build Chunk B +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 1)], +# edges=[ +# (to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), +# (to_label(cg, 1, 2, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 1), inf), +# (to_label(cg, 1, 2, 0, 0, 1), to_label(cg, 1, 2, 0, 0, 0), 0.5), +# ], +# 
timestamp=fake_timestamp, +# ) + +# add_layer( +# cg, +# 3, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 3, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 4, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) + +# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( +# to_label(cg, 1, 1, 0, 0, 1) +# ) +# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( +# to_label(cg, 1, 2, 0, 0, 0) +# ) + +# # Split +# new_root_ids = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 1, 0, 0, 0), +# sink_ids=to_label(cg, 1, 1, 0, 0, 2), +# mincut=False, +# ).new_root_ids + +# assert len(new_root_ids) == 2 + +# new_root_ids = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 1, 0, 0, 0), +# sink_ids=to_label(cg, 1, 1, 0, 0, 3), +# mincut=False, +# ).new_root_ids + +# assert len(new_root_ids) == 2 + +# cc_dict = cg.get_atomic_cross_edges( +# cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) +# ) +# assert len(cc_dict[3]) == 1 +# cc_dict = cg.get_atomic_cross_edges( +# cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) +# ) +# assert len(cc_dict[3]) == 1 + +# assert len(cg.get_latest_roots()) == 3 +# assert len(cg.get_latest_roots(fake_timestamp)) == 1 + +# @pytest.mark.timeout(30) +# def test_split_pair_disconnected_chunks(self, gen_graph): +# """ +# Remove edge between existing RG supervoxels 1 and 2 (disconnected chunks) +# ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ +# │ A¹ │ ... │ Z¹ │ │ A¹ │ ... │ Z¹ │ +# │ 1━━┿━━━━━┿━━2 │ => │ 1 │ │ 2 │ +# │ │ │ │ │ │ │ │ +# └─────┘ └─────┘ └─────┘ └─────┘ +# """ + +# cg = gen_graph(n_layers=9) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 0, 0, 0)], +# edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 7, 7, 7, 0), 1.0,)], +# timestamp=fake_timestamp, +# ) + +# # Preparation: Build Chunk Z +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 7, 7, 7, 0)], +# edges=[(to_label(cg, 1, 7, 7, 7, 0), to_label(cg, 1, 0, 0, 0, 0), 1.0,)], +# timestamp=fake_timestamp, +# ) + +# add_layer( +# cg, +# 3, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 3, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 4, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 4, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 5, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 5, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 6, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 6, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 7, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 7, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 8, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 8, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# 9, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) + +# # Split +# new_roots = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 7, 7, 7, 0), +# sink_ids=to_label(cg, 1, 0, 0, 0, 0), +# mincut=False, 
+# ).new_root_ids + +# # Check New State +# assert len(new_roots) == 2 +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( +# to_label(cg, 1, 7, 7, 7, 0) +# ) +# leaves = np.unique( +# cg.get_subgraph([cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True) +# ) +# assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves +# leaves = np.unique( +# cg.get_subgraph([cg.get_root(to_label(cg, 1, 7, 7, 7, 0))], leaves_only=True) +# ) +# assert len(leaves) == 1 and to_label(cg, 1, 7, 7, 7, 0) in leaves + +# # Check Old State still accessible +# assert cg.get_root( +# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp +# ) == cg.get_root(to_label(cg, 1, 7, 7, 7, 0), time_stamp=fake_timestamp) +# leaves = np.unique( +# cg.get_subgraph( +# [cg.get_root(to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp)], +# leaves_only=True, +# ) +# ) +# assert len(leaves) == 2 +# assert to_label(cg, 1, 0, 0, 0, 0) in leaves +# assert to_label(cg, 1, 7, 7, 7, 0) in leaves + +# @pytest.mark.timeout(30) +# def test_split_pair_already_disconnected(self, gen_graph): +# """ +# Try to remove edge between already disconnected RG supervoxels 1 and 2 (same chunk). +# Expected: No change, no error +# ┌─────┐ ┌─────┐ +# │ A¹ │ │ A¹ │ +# │ 1 2 │ => │ 1 2 │ +# │ │ │ │ +# └─────┘ └─────┘ +# """ + +# cg = gen_graph(n_layers=2) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], +# edges=[], +# timestamp=fake_timestamp, +# ) + +# res_old = cg.client._table.read_rows() +# res_old.consume_all() + +# # Split +# with pytest.raises(exceptions.PreconditionError): +# cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 0, 0, 0, 1), +# sink_ids=to_label(cg, 1, 0, 0, 0, 0), +# mincut=False, +# ) + +# res_new = cg.client._table.read_rows() +# res_new.consume_all() + +# # Check +# if res_old.rows != res_new.rows: +# warn( +# "Rows were modified when splitting a pair of already disconnected supervoxels. " +# "While probably not an error, it is an unnecessary operation." 
+# ) + +# @pytest.mark.timeout(30) +# def test_split_full_circle_to_triple_chain_same_chunk(self, gen_graph): +# """ +# Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection (same chunk) +# ┌─────┐ ┌─────┐ +# │ A¹ │ │ A¹ │ +# │ 1━2 │ => │ 1 2 │ +# │ ┗3┛ │ │ ┗3┛ │ +# └─────┘ └─────┘ +# """ + +# cg = gen_graph(n_layers=2) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[ +# to_label(cg, 1, 0, 0, 0, 0), +# to_label(cg, 1, 0, 0, 0, 1), +# to_label(cg, 1, 0, 0, 0, 2), +# ], +# edges=[ +# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 2), 0.5), +# (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2), 0.5), +# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.3), +# ], +# timestamp=fake_timestamp, +# ) + +# # Split +# new_root_ids = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 0, 0, 0, 1), +# sink_ids=to_label(cg, 1, 0, 0, 0, 0), +# mincut=False, +# ).new_root_ids + +# # Check New State +# assert len(new_root_ids) == 1 +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_ids[0] +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_ids[0] +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 2)) == new_root_ids[0] +# leaves = np.unique(cg.get_subgraph([new_root_ids[0]], leaves_only=True)) +# assert len(leaves) == 3 +# assert to_label(cg, 1, 0, 0, 0, 0) in leaves +# assert to_label(cg, 1, 0, 0, 0, 1) in leaves +# assert to_label(cg, 1, 0, 0, 0, 2) in leaves + +# # Check Old State still accessible +# old_root_id = cg.get_root( +# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp +# ) +# assert new_root_ids[0] != old_root_id + +# # assert len(cg.get_latest_roots()) == 1 +# # assert len(cg.get_latest_roots(fake_timestamp)) == 1 + +# @pytest.mark.timeout(30) +# def test_split_full_circle_to_triple_chain_neighboring_chunks(self, gen_graph): +# """ +# Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection (neighboring chunks) +# ┌─────┬─────┐ ┌─────┬─────┐ +# │ A¹ │ B¹ │ │ A¹ │ B¹ │ +# │ 1━━┿━━2 │ => │ 1 │ 2 │ +# │ ┗3━┿━━┛ │ │ ┗3━┿━━┛ │ +# └─────┴─────┘ └─────┴─────┘ +# """ + +# cg = gen_graph(n_layers=3) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], +# edges=[ +# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), +# (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 0), 0.5), +# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.3), +# ], +# timestamp=fake_timestamp, +# ) + +# # Preparation: Build Chunk B +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 1, 0, 0, 0)], +# edges=[ +# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), +# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.3), +# ], +# timestamp=fake_timestamp, +# ) + +# add_layer( +# cg, +# 3, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) + +# # Split +# new_root_ids = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 1, 0, 0, 0), +# sink_ids=to_label(cg, 1, 0, 0, 0, 0), +# mincut=False, +# ).new_root_ids + +# # Check New State +# assert len(new_root_ids) == 1 +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_ids[0] +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_ids[0] +# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == new_root_ids[0] +# leaves = np.unique(cg.get_subgraph([new_root_ids[0]], 
leaves_only=True)) +# assert len(leaves) == 3 +# assert to_label(cg, 1, 0, 0, 0, 0) in leaves +# assert to_label(cg, 1, 0, 0, 0, 1) in leaves +# assert to_label(cg, 1, 1, 0, 0, 0) in leaves + +# # Check Old State still accessible +# old_root_id = cg.get_root( +# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp +# ) +# assert new_root_ids[0] != old_root_id + +# assert len(cg.get_latest_roots()) == 1 +# assert len(cg.get_latest_roots(fake_timestamp)) == 1 + +# @pytest.mark.timeout(30) +# def test_split_full_circle_to_triple_chain_disconnected_chunks(self, gen_graph): +# """ +# Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection (disconnected chunks) +# ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ +# │ A¹ │ ... │ Z¹ │ │ A¹ │ ... │ Z¹ │ +# │ 1━━┿━━━━━┿━━2 │ => │ 1 │ │ 2 │ +# │ ┗3━┿━━━━━┿━━┛ │ │ ┗3━┿━━━━━┿━━┛ │ +# └─────┘ └─────┘ └─────┘ └─────┘ +# """ + +# cg = gen_graph(n_layers=9) + +# loc = 2 + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], +# edges=[ +# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), +# (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, loc, loc, loc, 0), 0.5,), +# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, loc, loc, loc, 0), 0.3,), +# ], +# timestamp=fake_timestamp, +# ) + +# # Preparation: Build Chunk Z +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, loc, loc, loc, 0)], +# edges=[ +# (to_label(cg, 1, loc, loc, loc, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5,), +# (to_label(cg, 1, loc, loc, loc, 0), to_label(cg, 1, 0, 0, 0, 0), 0.3,), +# ], +# timestamp=fake_timestamp, +# ) + +# for i_layer in range(3, 10): +# if loc // 2 ** (i_layer - 3) == 1: +# add_layer( +# cg, +# i_layer, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# elif loc // 2 ** (i_layer - 3) == 0: +# add_layer( +# cg, +# i_layer, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# else: +# add_layer( +# cg, +# i_layer, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) +# add_layer( +# cg, +# i_layer, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) + +# assert ( +# cg.get_root(to_label(cg, 1, loc, loc, loc, 0)) +# == cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) +# == cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) +# ) + +# # Split +# new_root_ids = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, loc, loc, loc, 0), +# sink_ids=to_label(cg, 1, 0, 0, 0, 0), +# mincut=False, +# ).new_root_ids + +# # Check New State +# assert len(new_root_ids) == 1 +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_ids[0] +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_ids[0] +# assert cg.get_root(to_label(cg, 1, loc, loc, loc, 0)) == new_root_ids[0] +# leaves = np.unique(cg.get_subgraph([new_root_ids[0]], leaves_only=True)) +# assert len(leaves) == 3 +# assert to_label(cg, 1, 0, 0, 0, 0) in leaves +# assert to_label(cg, 1, 0, 0, 0, 1) in leaves +# assert to_label(cg, 1, loc, loc, loc, 0) in leaves + +# # Check Old State still accessible +# old_root_id = cg.get_root( +# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp +# ) +# assert new_root_ids[0] != old_root_id + +# assert len(cg.get_latest_roots()) == 1 +# assert len(cg.get_latest_roots(fake_timestamp)) == 1 + +# @pytest.mark.timeout(30) +# def test_split_same_node(self, gen_graph): +# """ +# Try to remove (non-existing) edge between RG supervoxel 1 and itself +# ┌─────┐ +# │ A¹ │ +# │ 1 │ 
=> Reject +# │ │ +# └─────┘ +# """ + +# cg = gen_graph(n_layers=2) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 0, 0, 0)], +# edges=[], +# timestamp=fake_timestamp, +# ) + +# res_old = cg.client._table.read_rows() +# res_old.consume_all() + +# # Split +# with pytest.raises(exceptions.PreconditionError): +# cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 0, 0, 0, 0), +# sink_ids=to_label(cg, 1, 0, 0, 0, 0), +# mincut=False, +# ) + +# res_new = cg.client._table.read_rows() +# res_new.consume_all() + +# assert res_new.rows == res_old.rows + +# @pytest.mark.timeout(30) +# def test_split_pair_abstract_nodes(self, gen_graph): +# """ +# Try to remove (non-existing) edge between RG supervoxel 1 and abstract node "2" +# ┌─────┐ +# │ B² │ +# │ "2" │ +# │ │ +# └─────┘ +# ┌─────┐ => Reject +# │ A¹ │ +# │ 1 │ +# │ │ +# └─────┘ +# """ + +# cg = gen_graph(n_layers=3) + +# # Preparation: Build Chunk A +# fake_timestamp = datetime.utcnow() - timedelta(days=10) +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 0, 0, 0)], +# edges=[], +# timestamp=fake_timestamp, +# ) + +# # Preparation: Build Chunk B +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 1, 0, 0, 0)], +# edges=[], +# timestamp=fake_timestamp, +# ) + +# add_layer( +# cg, +# 3, +# [0, 0, 0], +# +# time_stamp=fake_timestamp, +# n_threads=1, +# ) + +# res_old = cg.client._table.read_rows() +# res_old.consume_all() + +# # Split +# with pytest.raises(exceptions.PreconditionError): +# cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 0, 0, 0, 0), +# sink_ids=to_label(cg, 2, 1, 0, 0, 1), +# mincut=False, +# ) + +# res_new = cg.client._table.read_rows() +# res_new.consume_all() + +# assert res_new.rows == res_old.rows + +# @pytest.mark.timeout(30) +# def test_diagonal_connections(self, gen_graph): +# """ +# Create graph with edge between RG supervoxels 1 and 2 (same chunk) +# and edge between RG supervoxels 1 and 3 (neighboring chunks) +# ┌─────┬─────┐ +# │ A¹ │ B¹ │ +# │ 2━1━┿━━3 │ +# │ / │ │ +# ┌─────┬─────┐ +# │ | │ │ +# │ 4━━┿━━5 │ +# │ C¹ │ D¹ │ +# └─────┴─────┘ +# """ + +# cg = gen_graph(n_layers=3) + +# # Chunk A +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], +# edges=[ +# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), +# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), +# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf), +# ], +# ) + +# # Chunk B +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 1, 0, 0, 0)], +# edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], +# ) + +# # Chunk C +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 0, 1, 0, 0)], +# edges=[ +# (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 1, 1, 0, 0), inf), +# (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf), +# ], +# ) + +# # Chunk D +# create_chunk( +# cg, +# vertices=[to_label(cg, 1, 1, 1, 0, 0)], +# edges=[(to_label(cg, 1, 1, 1, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf)], +# ) + +# add_layer( +# cg, 3, [0, 0, 0], n_threads=1, +# ) + +# rr = cg.range_read_chunk(chunk_id=cg.get_chunk_id(layer=3, x=0, y=0, z=0)) +# root_ids_t0 = list(rr.keys()) + +# assert len(root_ids_t0) == 1 + +# child_ids = [] +# for root_id in root_ids_t0: +# child_ids.extend([cg.get_subgraph([root_id])], leaves_only=True) + +# new_roots = cg.remove_edges( +# "Jane Doe", +# source_ids=to_label(cg, 1, 0, 0, 0, 0), +# 
sink_ids=to_label(cg, 1, 0, 0, 0, 1), +# mincut=False, +# ).new_root_ids + +# assert len(new_roots) == 2 +# assert cg.get_root(to_label(cg, 1, 1, 1, 0, 0)) == cg.get_root( +# to_label(cg, 1, 0, 1, 0, 0) +# ) +# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == cg.get_root( +# to_label(cg, 1, 0, 0, 0, 0) +# ) diff --git a/pychunkedgraph/utils/general.py b/pychunkedgraph/utils/general.py index 108762a2b..71e24eab0 100644 --- a/pychunkedgraph/utils/general.py +++ b/pychunkedgraph/utils/general.py @@ -1,22 +1,41 @@ -import redis -import functools - -def redis_job(redis_url, redis_channel): - ''' - Decorator factory - Returns a decorator that connects to a redis instance - and publish a message (return value of the function) when the job is done. - ''' - def redis_job_decorator(func): - r = redis.Redis.from_url(redis_url) - - @functools.wraps(func) - def wrapper(*args, **kwargs): - job_result = func(*args, **kwargs) - if not job_result: - job_result = str(job_result) - r.publish(redis_channel, job_result) - - return wrapper - - return redis_job_decorator \ No newline at end of file +""" +generic helper funtions +""" +from typing import Sequence +from itertools import islice + +import numpy as np + + +def reverse_dictionary(dictionary): + """ + given a dictionary - {key1 : [item1, item2 ...], key2 : [ite3, item4 ...]} + return {item1: key1, item2: key1, item3: key2, item4: key2 } + """ + keys = [] + vals = [] + for key, values in dictionary.items(): + keys.append([key] * len(values)) + vals.append(values) + keys = np.concatenate(keys) + vals = np.concatenate(vals) + + return {k: v for k, v in zip(vals, keys)} + + +def chunked(l: Sequence, n: int): + """ + Yield successive n-sized chunks from l. + NOTE: Use itertools.batched from python 3.12 + """ + if n < 1: + n = len(l) + it = iter(l) + while batch := tuple(islice(it, n)): + yield batch + + +def in2d(arr1: np.ndarray, arr2: np.ndarray) -> np.ndarray: + arr1_view = arr1.view(dtype="u8,u8").reshape(arr1.shape[0]) + arr2_view = arr2.view(dtype="u8,u8").reshape(arr2.shape[0]) + return np.in1d(arr1_view, arr2_view) diff --git a/pychunkedgraph/utils/redis.py b/pychunkedgraph/utils/redis.py new file mode 100644 index 000000000..420a849f1 --- /dev/null +++ b/pychunkedgraph/utils/redis.py @@ -0,0 +1,35 @@ +""" +redis helper funtions +""" + +import os +from collections import namedtuple + +import redis +from rq import Queue + +REDIS_HOST = os.environ.get( + "REDIS_SERVICE_HOST", + os.environ.get("REDIS_HOST", "localhost"), +) +REDIS_PORT = os.environ.get( + "REDIS_SERVICE_PORT", + os.environ.get("REDIS_PORT", "6379"), +) +REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "") +REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0" + +keys_fields = ("INGESTION_MANAGER",) +keys_defaults = ("pcg:imanager",) +Keys = namedtuple("keys", keys_fields, defaults=keys_defaults) + +keys = Keys() + + +def get_redis_connection(redis_url=REDIS_URL): + return redis.Redis.from_url(redis_url) + + +def get_rq_queue(queue): + connection = redis.Redis.from_url(REDIS_URL) + return Queue(queue, connection=connection) diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..9b1a97928 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,10 @@ +pylint +black +pyopenssl +jupyter +codecov +ipython +pytest +pytest-cov +pytest-mock +pytest-timeout \ No newline at end of file diff --git a/requirements.in b/requirements.in new file mode 100644 index 000000000..63e0b3472 --- /dev/null +++ b/requirements.in @@ -0,0 +1,33 @@ +click>=8.0 
+protobuf>=4.22.0 +requests>=2.25.0 +grpcio>=1.36.1 +numpy +pandas +networkx>=2.1 +google-cloud-bigtable>=0.33.0 +google-cloud-datastore>=1.8 +flask +flask_cors +python-json-logger +redis +rq<2 +pyyaml +cachetools +werkzeug + +# PyPI only: +cloud-files>=4.21.1 +cloud-volume>=8.26.0 +multiwrapper +middle-auth-client>=3.11.0 +zmesh>=1.7.0 +fastremap>=1.14.0 +task-queue>=2.13.0 +messagingclient +dracopy>=1.3.0 +datastoreflex>=0.5.0 +zstandard==0.21.0 + +# Conda only - use requirements.yml (or install manually): +# graph-tool \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a4df79be8..5a2f18adc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,374 @@ -numpy -pandas -networkx==2.1 -google-cloud-bigtable>=0.33.0 -google-cloud-datastore>=1.8<=2.0dev -flask -flask_cors -codecov -multiwrapper -cloud-volume -python-json-logger -zstandard -redis -rq -middle-auth-client>=0.1.6 -dracopy==0.0.11 -zmesh -fastremap -pyopenssl \ No newline at end of file +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile --output-file=requirements.txt requirements.in +# +attrs==23.1.0 + # via + # jsonschema + # referencing +blinker==1.6.2 + # via flask +boto3==1.28.52 + # via + # cloud-files + # cloud-volume + # task-queue +botocore==1.31.52 + # via + # boto3 + # s3transfer +brotli==1.1.0 + # via + # cloud-files + # urllib3 +cachetools==5.3.1 + # via + # -r requirements.in + # google-auth + # middle-auth-client +certifi==2023.7.22 + # via requests +chardet==5.2.0 + # via + # cloud-files + # cloud-volume +charset-normalizer==3.2.0 + # via requests +click==8.1.7 + # via + # -r requirements.in + # cloud-files + # compressed-segmentation + # compresso + # flask + # rq + # task-queue +cloud-files==4.21.1 + # via + # -r requirements.in + # cloud-volume + # datastoreflex +cloud-volume==8.26.0 + # via -r requirements.in +compressed-segmentation==2.2.1 + # via cloud-volume +compresso==3.2.1 + # via cloud-volume +crackle-codec==0.7.0 + # via cloud-volume +crc32c==2.3.post0 + # via cloud-files +datastoreflex==0.5.0 + # via -r requirements.in +deflate==0.4.0 + # via cloud-files +dill==0.3.7 + # via + # multiprocess + # pathos +dracopy==1.3.0 + # via + # -r requirements.in + # cloud-volume +fasteners==0.19 + # via cloud-files +fastremap==1.14.0 + # via + # -r requirements.in + # cloud-volume + # crackle-codec +flask==2.3.3 + # via + # -r requirements.in + # flask-cors + # middle-auth-client +flask-cors==4.0.0 + # via -r requirements.in +fpzip==1.2.2 + # via cloud-volume +furl==2.1.3 + # via middle-auth-client +gevent==23.9.1 + # via + # cloud-files + # cloud-volume + # task-queue +google-api-core[grpc]==2.11.1 + # via + # google-api-core + # google-cloud-bigtable + # google-cloud-core + # google-cloud-datastore + # google-cloud-pubsub + # google-cloud-storage +google-auth==2.23.0 + # via + # cloud-files + # cloud-volume + # google-api-core + # google-cloud-core + # google-cloud-storage + # task-queue +google-cloud-bigtable==2.21.0 + # via -r requirements.in +google-cloud-core==2.3.3 + # via + # cloud-files + # cloud-volume + # google-cloud-bigtable + # google-cloud-datastore + # google-cloud-storage + # task-queue +google-cloud-datastore==2.18.0 + # via + # -r requirements.in + # datastoreflex +google-cloud-pubsub==2.18.4 + # via messagingclient +google-cloud-storage==2.11.0 + # via + # cloud-files + # cloud-volume +google-crc32c==1.5.0 + # via + # cloud-files + # google-resumable-media +google-resumable-media==2.6.0 + # via 
google-cloud-storage +googleapis-common-protos[grpc]==1.60.0 + # via + # google-api-core + # grpc-google-iam-v1 + # grpcio-status +greenlet==3.0.0rc3 + # via gevent +grpc-google-iam-v1==0.12.6 + # via + # google-cloud-bigtable + # google-cloud-pubsub +grpcio==1.58.0 + # via + # -r requirements.in + # google-api-core + # google-cloud-pubsub + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status +grpcio-status==1.58.0 + # via + # google-api-core + # google-cloud-pubsub +idna==3.4 + # via requests +inflection==0.5.1 + # via python-jsonschema-objects +iniconfig==2.0.0 + # via pytest +itsdangerous==2.1.2 + # via flask +jinja2==3.1.3 + # via flask +jmespath==1.0.1 + # via + # boto3 + # botocore +json5==0.9.14 + # via cloud-volume +jsonschema==4.19.1 + # via + # cloud-volume + # python-jsonschema-objects +jsonschema-specifications==2023.7.1 + # via jsonschema +markdown==3.4.4 + # via python-jsonschema-objects +markupsafe==2.1.3 + # via + # jinja2 + # werkzeug +messagingclient==0.1.3 + # via -r requirements.in +middle-auth-client==3.16.1 + # via -r requirements.in +multiprocess==0.70.15 + # via pathos +multiwrapper==0.1.1 + # via -r requirements.in +networkx==3.1 + # via + # -r requirements.in + # cloud-volume +numpy==1.26.0 + # via + # -r requirements.in + # cloud-volume + # compressed-segmentation + # compresso + # crackle-codec + # fastremap + # fpzip + # messagingclient + # multiwrapper + # pandas + # pyspng-seunglab + # simplejpeg + # task-queue + # zfpc + # zmesh +orderedmultidict==1.0.1 + # via furl +orjson==3.9.7 + # via + # cloud-files + # task-queue +packaging==23.1 + # via pytest +pandas==2.1.1 + # via -r requirements.in +pathos==0.3.1 + # via + # cloud-files + # cloud-volume + # task-queue +pbr==5.11.1 + # via task-queue +pillow==10.0.1 + # via cloud-volume +pluggy==1.3.0 + # via pytest +posix-ipc==1.1.1 + # via cloud-volume +pox==0.3.3 + # via pathos +ppft==1.7.6.7 + # via pathos +proto-plus==1.22.3 + # via + # google-cloud-bigtable + # google-cloud-datastore + # google-cloud-pubsub +protobuf==4.24.3 + # via + # -r requirements.in + # cloud-files + # cloud-volume + # google-api-core + # google-cloud-bigtable + # google-cloud-datastore + # google-cloud-pubsub + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status + # proto-plus +psutil==5.9.5 + # via cloud-volume +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth +pybind11==2.11.1 + # via crackle-codec +pysimdjson==5.0.2 + # via cloud-volume +pyspng-seunglab==1.1.0 + # via cloud-volume +pytest==7.4.2 + # via compressed-segmentation +python-dateutil==2.8.2 + # via + # botocore + # cloud-volume + # pandas +python-json-logger==2.0.7 + # via -r requirements.in +python-jsonschema-objects==0.5.0 + # via cloud-volume +pytz==2023.3.post1 + # via pandas +pyyaml==6.0.1 + # via -r requirements.in +redis==5.0.0 + # via + # -r requirements.in + # rq +referencing==0.30.2 + # via + # jsonschema + # jsonschema-specifications +requests==2.31.0 + # via + # -r requirements.in + # cloud-files + # cloud-volume + # google-api-core + # google-cloud-storage + # middle-auth-client + # task-queue +rpds-py==0.10.3 + # via + # jsonschema + # referencing +rq==1.15.1 + # via -r requirements.in +rsa==4.9 + # via + # cloud-files + # google-auth +s3transfer==0.6.2 + # via boto3 +simplejpeg==1.7.2 + # via cloud-volume +six==1.16.0 + # via + # cloud-files + # cloud-volume + # furl + # orderedmultidict + # python-dateutil + # python-jsonschema-objects +task-queue==2.13.0 + # via -r requirements.in 
+tenacity==8.2.3 + # via + # cloud-files + # cloud-volume + # task-queue +tqdm==4.66.1 + # via + # cloud-files + # cloud-volume + # task-queue +tzdata==2023.3 + # via pandas +urllib3[brotli]==1.26.16 + # via + # botocore + # cloud-files + # cloud-volume + # google-auth + # requests +werkzeug==2.3.8 + # via + # -r requirements.in + # flask +zfpc==0.1.2 + # via cloud-volume +zfpy==1.0.0 + # via zfpc +zmesh==1.7.0 + # via -r requirements.in +zope-event==5.0 + # via gevent +zope-interface==6.0 + # via gevent +zstandard==0.21.0 + # via + # -r requirements.in + # cloud-files + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements.yml b/requirements.yml new file mode 100644 index 000000000..0bfa5b227 --- /dev/null +++ b/requirements.yml @@ -0,0 +1,13 @@ +name: pychunkedgraph +channels: + - conda-forge +dependencies: + - python==3.11.4 + - pip + - tox + - uwsgi==2.0.21 + - graph-tool-base==2.58 + - zstandard==0.19.0 # ugly hack to force PyPi install 0.21.0 + - pip: + - -r requirements.txt + - -r requirements-dev.txt \ No newline at end of file diff --git a/rq_workers/ingest_worker.py b/rq_workers/ingest_worker.py deleted file mode 100644 index dfe265e34..000000000 --- a/rq_workers/ingest_worker.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -# This is for monitoring rq with supervisord -# For the flask app use a config class - -# env REDIS_SERVICE_HOST and REDIS_SERVICE_PORT are added by Kubernetes -REDIS_HOST = os.environ.get('REDIS_SERVICE_HOST') -REDIS_PORT = os.environ.get('REDIS_SERVICE_PORT') -REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD') -REDIS_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0' - -# Queues to listen on -QUEUES = ['default', 'ingest-chunks'] - -# If you're using Sentry to collect your runtime exceptions, you can use this -# to configure RQ for it in a single step -# The 'sync+' prefix is required for raven: https://github.com/nvie/rq/issues/350#issuecomment-43592410 -# SENTRY_DSN = 'sync+http://public:secret@example.com/1' - -# If you want custom worker name -# NAME = 'worker-1024' diff --git a/rq_workers/mesh_worker.py b/rq_workers/mesh_worker.py deleted file mode 100644 index 2a0ececda..000000000 --- a/rq_workers/mesh_worker.py +++ /dev/null @@ -1,23 +0,0 @@ -import os -# This is for monitoring rq with supervisord -# For the flask app use a config class - -# env REDIS_SERVICE_HOST and REDIS_SERVICE_PORT are added by Kubernetes -REDIS_HOST = os.environ.get('REDIS_SERVICE_HOST') -REDIS_PORT = os.environ.get('REDIS_SERVICE_PORT') -REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD') -if REDIS_PASSWORD is None: - REDIS_URL = f'redis://@{REDIS_HOST}:{REDIS_PORT}/0' -else: - REDIS_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0' - -# Queues to listen on -QUEUES = ['default', 'mesh-chunks'] - -# If you're using Sentry to collect your runtime exceptions, you can use this -# to configure RQ for it in a single step -# The 'sync+' prefix is required for raven: https://github.com/nvie/rq/issues/350#issuecomment-43592410 -# SENTRY_DSN = 'sync+http://public:secret@example.com/1' - -# If you want custom worker name -# NAME = 'worker-1024' diff --git a/rq_workers/test_worker.py b/rq_workers/test_worker.py deleted file mode 100644 index 256957b04..000000000 --- a/rq_workers/test_worker.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -# This is for monitoring rq with supervisord -# For the flask app use a config class - -# env REDIS_SERVICE_HOST and REDIS_SERVICE_PORT are added by Kubernetes -REDIS_HOST = 
os.environ.get('REDIS_SERVICE_HOST') -REDIS_PORT = os.environ.get('REDIS_SERVICE_PORT') -REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD') -REDIS_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0' - -# Queues to listen on -QUEUES = ['test'] - -# If you're using Sentry to collect your runtime exceptions, you can use this -# to configure RQ for it in a single step -# The 'sync+' prefix is required for raven: https://github.com/nvie/rq/issues/350#issuecomment-43592410 -# SENTRY_DSN = 'sync+http://public:secret@example.com/1' - -# If you want custom worker name -# NAME = 'worker-1024' diff --git a/run_dev.py b/run_dev.py index aeec6b1e7..24906b159 100644 --- a/run_dev.py +++ b/run_dev.py @@ -1,27 +1,23 @@ import sys from werkzeug.serving import WSGIRequestHandler -import os from pychunkedgraph.app import create_app app = create_app() -if __name__ == '__main__': - - assert len(sys.argv) == 2 - HOME = os.path.expanduser("~") +if __name__ == "__main__": + assert len(sys.argv) >= 2 port = int(sys.argv[1]) - - # Set HTTP protocol WSGIRequestHandler.protocol_version = "HTTP/1.1" - # WSGIRequestHandler.protocol_version = "HTTP/2.0" - - print("Table: %s; Port: %d" % - (app.config['CHUNKGRAPH_TABLE_ID'], port)) - app.run(host='0.0.0.0', + if len(sys.argv) == 2: + app.run( + host="0.0.0.0", port=port, debug=True, threaded=True, - ssl_context='adhoc') \ No newline at end of file + ssl_context="adhoc", + ) + else: + app.run(host="0.0.0.0", port=port, debug=True, threaded=True) diff --git a/run_dev_cli.py b/run_dev_cli.py deleted file mode 100644 index 328179dbc..000000000 --- a/run_dev_cli.py +++ /dev/null @@ -1,10 +0,0 @@ -from flask.cli import FlaskGroup -from pychunkedgraph.examples import create_example_app - - -app = create_example_app() -cli = FlaskGroup(create_app=create_example_app) - - -if __name__ == '__main__': - cli() \ No newline at end of file diff --git a/setup.py b/setup.py index 8891b37c1..e71fcab1b 100644 --- a/setup.py +++ b/setup.py @@ -5,15 +5,15 @@ here = os.path.abspath(os.path.dirname(__file__)) + def read(*parts): - with codecs.open(os.path.join(here, *parts), 'r') as fp: + with codecs.open(os.path.join(here, *parts), "r") as fp: return fp.read() def find_version(*file_paths): version_file = read(*file_paths) - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - version_file, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) if version_match: return version_match.group(1) raise RuntimeError("Unable to find version string.") @@ -22,7 +22,7 @@ def find_version(*file_paths): with open("README.md", "r") as fh: long_description = fh.read() -with open('requirements.txt', 'r') as f: +with open("requirements.txt", "r") as f: required = f.read().splitlines() dependency_links = [] @@ -40,7 +40,7 @@ def find_version(*file_paths): setup( name="PyChunkedGraph", - version=find_version('pychunkedgraph', '__init__.py'), + version=find_version("pychunkedgraph", "__init__.py"), author="Sven Dorkenwald", author_email="svenmd@princeton.edu", description="Proofreading backend for Neuroglancer", @@ -54,5 +54,5 @@ def find_version(*file_paths): "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: POSIX :: Linux", - ] + ], ) diff --git a/tox.ini b/tox.ini index cfa25c129..5398564e6 100644 --- a/tox.ini +++ b/tox.ini @@ -1,14 +1,13 @@ [tox] -envlist = py36dev +envlist = py311 +requires = tox-conda [testenv] setenv = HOME = {env:HOME} -usedevelop = true -sitepackages = true + deps = pytest 
pytest-cov pytest-mock pytest-timeout - numpy -commands = python -m pytest {posargs} ./pychunkedgraph/tests/ -install_command = {toxinidir}/tox_install_command.sh {opts} {packages} +conda_env = requirements.yml +commands = python -m pytest -v {posargs} ./pychunkedgraph/tests/ diff --git a/tox_install_command.sh b/tox_install_command.sh deleted file mode 100755 index 4fa3935a9..000000000 --- a/tox_install_command.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -pip install pip==18.1 -pip install --process-dependency-links "$@" \ No newline at end of file diff --git a/tracker.py b/tracker.py new file mode 100644 index 000000000..d2ae63cb3 --- /dev/null +++ b/tracker.py @@ -0,0 +1,22 @@ +import sys +from rq import Connection, Worker + +# Preload libraries from pychunkedgraph.ingest.cluster +from typing import Sequence, Tuple + +import numpy as np + +from pychunkedgraph.ingest.utils import chunk_id_str +from pychunkedgraph.ingest.manager import IngestionManager +from pychunkedgraph.ingest.common import get_atomic_chunk_data +from pychunkedgraph.ingest.ran_agglomeration import get_active_edges +from pychunkedgraph.ingest.create.atomic_layer import add_atomic_edges +from pychunkedgraph.ingest.create.abstract_layers import add_layer +from pychunkedgraph.graph.meta import ChunkedGraphMeta +from pychunkedgraph.graph.chunks.hierarchy import get_children_chunk_coords +from pychunkedgraph.utils.redis import keys as r_keys +from pychunkedgraph.utils.redis import get_redis_connection + +qs = sys.argv[1:] +w = Worker(qs, connection=get_redis_connection()) +w.work() \ No newline at end of file diff --git a/uwsgi.ini b/uwsgi.ini index 1917bbae6..776e2ff00 100644 --- a/uwsgi.ini +++ b/uwsgi.ini @@ -13,10 +13,14 @@ gid = nginx env = HOME=/home/nginx +# Python venv +if-env = VIRTUAL_ENV +virtualenv = %(_) +endif = ### Worker scaling # maximum number of workers -processes = 64 +processes = 16 # https://uwsgi-docs.readthedocs.io/en/latest/Cheaper.html#busyness-cheaper-algorithm cheaper-algo = busyness @@ -48,7 +52,7 @@ cheaper-busyness-min = 20 listen = 4096 # Max request header size -buffer-size = 4096 +buffer-size = 65535 # Don't spawn new workers if total memory over 6 GiB cheaper-rss-limit-soft = 6442450944 @@ -104,3 +108,11 @@ log-route = unknown ^(?:(?!^{address space usage|\[warn\]|^{.*"source".*}$).)*$ log-encoder = json:unknown {"source":"unknown","time":"${strftime:%Y-%m-%dT%H:%M:%S.000Z}","severity":"error","message":"${msg}"} log-encoder = nl:unknown + +log-4xx = true +log-5xx = true +disable-logging = true + +stats = 127.0.0.1:9191 +stats-http = 127.0.0.1:9192 +stats-interval = 5 \ No newline at end of file diff --git a/rq_workers/__init__.py b/workers/__init__.py similarity index 100% rename from rq_workers/__init__.py rename to workers/__init__.py diff --git a/workers/mesh_worker.py b/workers/mesh_worker.py new file mode 100644 index 000000000..b8f1e0024 --- /dev/null +++ b/workers/mesh_worker.py @@ -0,0 +1,76 @@ +# pylint: disable=invalid-name, missing-docstring, too-many-locals, logging-fstring-interpolation + +import gc +import pickle +import logging +from os import path +from os import getenv + +import numpy as np +from messagingclient import MessagingClient + +from pychunkedgraph.graph import ChunkedGraph +from pychunkedgraph.graph.utils import basetypes +from pychunkedgraph.meshing import meshgen + + +PCG_CACHE = {} + + +def callback(payload): + data = pickle.loads(payload.data) + op_id = int(data["operation_id"]) + l2ids = np.array(data["new_lvl2_ids"], dtype=basetypes.NODE_ID) + table_id 
= payload.attributes["table_id"] + remesh = payload.attributes["remesh"] + + if remesh == "false": + return + + try: + cg = PCG_CACHE[table_id] + except KeyError: + cg = ChunkedGraph(graph_id=table_id) + PCG_CACHE[table_id] = cg + + INFO_HIGH = 25 + logging.basicConfig( + level=INFO_HIGH, + format="%(asctime)s %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", + ) + + try: + mesh_meta = cg.meta.custom_data["mesh"] + mesh_dir = mesh_meta["dir"] + layer = mesh_meta["max_layer"] + mip = mesh_meta["mip"] + err = mesh_meta["max_error"] + cv_unsharded_mesh_dir = mesh_meta.get("dynamic_mesh_dir", "dynamic") + except KeyError: + logging.warning(f"No metadata found for {cg.graph_id}; ignoring...") + return + + mesh_path = path.join( + cg.meta.data_source.WATERSHED, mesh_dir, cv_unsharded_mesh_dir + ) + + + logging.log(INFO_HIGH, f"remeshing {l2ids}; graph {table_id} operation {op_id}.") + meshgen.remeshing( + cg, + l2ids, + stop_layer=layer, + mip=mip, + max_err=err, + cv_sharded_mesh_dir=mesh_dir, + cv_unsharded_mesh_path=mesh_path, + ) + logging.log(INFO_HIGH, f"remeshing complete; graph {table_id} operation {op_id}.") + gc.collect() + + +c = MessagingClient() +remesh_queue = getenv("PYCHUNKEDGRAPH_REMESH_QUEUE") +assert remesh_queue is not None, "env PYCHUNKEDGRAPH_REMESH_QUEUE not specified." +c.consume(remesh_queue, callback)
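The new mesh worker above only defines the consuming side; what a producer has to publish can be read off from what callback() unpacks. Below is a minimal sketch of that payload contract, not of messagingclient's publish API (which is not shown in this diff): the SimpleNamespace stand-in, the operation id 42, the level-2 id, and the table name "my_graph_table" are all hypothetical values used purely for illustration.

import pickle
from types import SimpleNamespace

# Body: a pickled dict carrying the edit operation id and the new level-2 ids to remesh
# (callback() casts new_lvl2_ids to basetypes.NODE_ID before handing them to meshgen.remeshing).
body = pickle.dumps(
    {
        "operation_id": 42,                    # hypothetical edit operation id
        "new_lvl2_ids": [173122905948487680],  # hypothetical level-2 node ids
    }
)

# Attributes: table_id selects which ChunkedGraph to (re)use from PCG_CACHE;
# remesh="false" makes callback() return before any graph access.
payload = SimpleNamespace(
    data=body,
    attributes={"table_id": "my_graph_table", "remesh": "true"},
)

# callback(payload)  # would look up ChunkedGraph(graph_id="my_graph_table") and remesh the l2 ids

Setting remesh to "false" in the attributes makes the callback return right after parsing the message, which is a cheap way to exercise the payload handling without a reachable BigTable instance.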