diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..9eededb9 --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +transformers_cache \ No newline at end of file diff --git a/Dockerfile.2 b/Dockerfile.2 new file mode 100644 index 00000000..d01ed71b --- /dev/null +++ b/Dockerfile.2 @@ -0,0 +1,348 @@ +ARG BASE_IMAGE + +# ------------------------ +# Target: dev +# ------------------------ +FROM $BASE_IMAGE as dev + +ARG TOOLKIT_USER_ID=13011 +ARG TOOLKIT_GROUP_ID=13011 + +RUN apt-get update \ + # Required to save git hashes + && apt-get install -y -q git curl unzip make gettext \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app +ENV XDG_DATA_HOME=/app/.local/share \ + XDG_CACHE_HOME=/app/.cache \ + XDG_BIN_HOME=/app/.local/bin \ + XDG_CONFIG_HOME=/app/.config +RUN mkdir -p $XDG_DATA_HOME \ + && mkdir -p $XDG_CACHE_HOME \ + && mkdir -p $XDG_BIN_HOME \ + && mkdir -p $XDG_CONFIG_HOME \ + && chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app + +# Install C++ toolchain, Facebook thrift, and dependencies +RUN curl https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - \ + && apt-get update \ + && apt-get install -y --no-install-recommends software-properties-common \ + && apt-add-repository "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-9 main" \ + && apt-get update \ + && apt-get install -y --no-install-recommends \ + binfmt-support libllvm9 llvm-9 llvm-9-dev llvm-9-runtime llvm-9-tools python-chardet python-pygments python-yaml \ + g++ \ + cmake \ + libboost-all-dev \ + libevent-dev \ + libdouble-conversion-dev \ + libgoogle-glog-dev \ + libgflags-dev \ + libiberty-dev \ + liblz4-dev \ + liblzma-dev \ + libsnappy-dev \ + make \ + zlib1g-dev \ + binutils-dev \ + libjemalloc-dev \ + libssl-dev \ + pkg-config \ + libunwind-dev \ + libmysqlclient-dev \ + bison \ + flex \ + libsodium-dev \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID third_party/zstd /app/third_party/zstd/ +RUN cd /app/third_party/zstd \ + && make -j4 \ + && make install \ + && make clean +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID third_party/fmt /app/third_party/fmt/ +RUN cd /app/third_party/fmt/ \ + && mkdir _build \ + && cd _build \ + && cmake -DBUILD_SHARED_LIBS=ON -DBUILD_EXAMPLES=off -DBUILD_TESTS=off ../. \ + && make -j4 \ + && make install \ + && cd .. \ + && rm -rf _build +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID third_party/folly /app/third_party/folly/ +RUN pip install cython \ + && cd /app/third_party/folly \ + && mkdir _build \ + && cd _build \ + && cmake -DBUILD_SHARED_LIBS=ON -DPYTHON_EXTENSIONS=ON -DBUILD_EXAMPLES=off -DBUILD_TESTS=off ../. \ + && make -j4 \ + && make install \ + && cp folly/cybld/dist/folly-0.0.1-cp37-cp37m-linux_x86_64.whl /app/ \ + && chown $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/folly-0.0.1-cp37-cp37m-linux_x86_64.whl \ + && pip install /app/folly-0.0.1-cp37-cp37m-linux_x86_64.whl \ + && cd .. \ + && rm -rf _build +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID third_party/rsocket-cpp /app/third_party/rsocket-cpp/ +RUN cd /app/third_party/rsocket-cpp \ + && mkdir _build \ + && cd _build \ + && cmake -DBUILD_SHARED_LIBS=ON -DBUILD_EXAMPLES=off -DBUILD_TESTS=off ../. \ + && make -j4 \ + && make install \ + && cd .. 
\ + && rm -rf _build +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID third_party/fizz /app/third_party/fizz/ +RUN cd /app/third_party/fizz \ + && mkdir _build \ + && cd _build \ + && cmake -DBUILD_SHARED_LIBS=ON -DBUILD_EXAMPLES=off -DBUILD_TESTS=off ../fizz \ + && make -j4 \ + && make install \ + && cd .. \ + && rm -rf _build +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID third_party/wangle /app/third_party/wangle/ +RUN cd /app/third_party/wangle \ + && mkdir _build \ + && cd _build \ + && cmake -DBUILD_SHARED_LIBS=ON -DBUILD_EXAMPLES=off -DBUILD_TESTS=off ../wangle \ + && make -j4 \ + && make install \ + && cd .. \ + && rm -rf _build +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID third_party/fbthrift /app/third_party/fbthrift/ +RUN cd /app/third_party/fbthrift \ + && mkdir _build \ + && cd _build \ + && cmake \ + -DBUILD_SHARED_LIBS=ON \ + -DPYTHON_INCLUDE_DIR=$(python -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())") \ + -DPYTHON_LIBRARY=$(python -c "import distutils.sysconfig as sysconfig; import os; print(os.path.join(sysconfig.get_config_var('LIBDIR'), sysconfig.get_config_var('LDLIBRARY')))") \ + -Dthriftpy3=ON \ + ../. \ + && make -j4 \ + && DESTDIR=/ make install \ + && cp thrift/lib/py3/cybld/dist/thrift-0.0.1-cp37-cp37m-linux_x86_64.whl /app/ \ + && chown $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/thrift-0.0.1-cp37-cp37m-linux_x86_64.whl \ + && pip install /app/thrift-0.0.1-cp37-cp37m-linux_x86_64.whl \ + && cd .. \ + && rm -rf _build + +# Install Rust toolchain +ENV RUSTUP_HOME=/app/.local/rustup \ + CARGO_HOME=/app/.local/cargo \ + PATH=/app/.local/cargo/bin:$PATH +RUN set -eux; \ + apt-get update; \ + apt-get install -y --no-install-recommends \ + ca-certificates \ + gcc \ + libc6-dev \ + wget \ + ; \ + \ + url="https://static.rust-lang.org/rustup/dist/x86_64-unknown-linux-gnu/rustup-init"; \ + wget "$url"; \ + chmod +x rustup-init; \ + ./rustup-init -y --no-modify-path --default-toolchain nightly-2021-06-01; \ + rm rustup-init; \ + chmod -R a+w $RUSTUP_HOME $CARGO_HOME; \ + rustup --version; \ + cargo --version; \ + rustc --version; \ + rm -rf /var/lib/apt/lists/*; \ + chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/.local/cargo; \ + chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/.local/rustup; + +# Install Haskell toolchain +ENV BOOTSTRAP_HASKELL_NONINTERACTIVE=yes \ + BOOTSTRAP_HASKELL_NO_UPGRADE=yes \ + GHCUP_USE_XDG_DIRS=yes \ + GHCUP_INSTALL_BASE_PREFIX=/app \ + CABAL_DIR=/app/.cabal \ + PATH=/app/.cabal/bin:/app/.local/bin:$PATH +RUN buildDeps=" \ + curl \ + "; \ + deps=" \ + libtinfo-dev \ + libgmp3-dev \ + "; \ + apt-get update \ + && apt-get install -y --no-install-recommends $buildDeps $deps \ + && curl --proto '=https' --tlsv1.2 -sSf https://get-ghcup.haskell.org | sh \ + && ghcup install ghc \ + && ghcup install cabal \ + && cabal update \ + && apt-get install -y --no-install-recommends git \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* \ + && git clone https://github.com/haskell/cabal.git \ + && cd cabal \ + && git checkout f5f8d933db229d30e6fc558f5335f0a4e85d7d44 \ + && sed -i 's/3.5.0.0/3.6.0.0/' */*.cabal \ + && cabal install cabal-install/ \ + --allow-newer=Cabal-QuickCheck:Cabal \ + --allow-newer=Cabal-described:Cabal \ + --allow-newer=Cabal-tree-diff:Cabal \ + --allow-newer=cabal-install:Cabal \ + --allow-newer=cabal-install-solver:Cabal \ + && cd .. 
\ + && rm -rf cabal/ \ + && rm -rf /app/.cabal/packages/* \ + && rm -rf /app/.cabal/logs/* \ + && rm -rf /app/.cache/ghcup \ + && chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/.cabal \ + && chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/.local/bin \ + && chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/.local/share/ghcup + +# Build Facebook hsthrift +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID third_party/hsthrift /app/third_party/hsthrift/ +RUN cd /app/third_party/hsthrift \ + && make thrift-cpp \ + && cabal update \ + && cabal build exe:thrift-compiler \ + && make thrift-hs \ + && cabal install exe:thrift-compiler \ + && cabal clean \ + && rm -rf /app/.cabal/packages/* \ + && rm -rf /app/.cabal/logs/* \ + && chown -h $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/.cabal/bin/thrift-compiler \ + && find /app/.cabal/store/ghc-8.10.*/ -maxdepth 2 -type d -group root -exec chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID {} \; \ + && find . -group root -exec chown $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID {} \; + +# Install misc utilities and add toolkit user +ENV LANG=en_US.UTF-8 +RUN apt update && \ + apt install -y \ + zsh fish gnupg lsb-release \ + ca-certificates supervisor openssh-server bash ssh tmux jq \ + curl wget vim procps htop locales nano man net-tools iputils-ping \ + openssl libicu[0-9][0-9] libkrb5-3 zlib1g gnome-keyring libsecret-1-0 desktop-file-utils x11-utils && \ + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg && \ + echo \ + "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null && \ + apt-get update && \ + apt-get install -y docker-ce docker-ce-cli containerd.io && \ + sed -i "s/# en_US.UTF-8/en_US.UTF-8/" /etc/locale.gen && \ + locale-gen && \ + useradd -m -u $TOOLKIT_USER_ID -s /bin/bash --non-unique toolkit && \ + passwd -d toolkit && \ + useradd -m -u $TOOLKIT_USER_ID -s /bin/bash --non-unique console && \ + passwd -d console && \ + useradd -m -u $TOOLKIT_USER_ID -s /bin/bash --non-unique _toolchain && \ + passwd -d _toolchain && \ + useradd -m -u $TOOLKIT_USER_ID -s /bin/bash --non-unique coder && \ + passwd -d coder && \ + chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /run /etc/shadow /etc/profile && \ + apt autoremove --purge && apt-get clean && \ + rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \ + echo ssh >> /etc/securetty && \ + rm -f /etc/legal /etc/motd + +# Build Huggingface tokenizers Rust libraries +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID third_party/tokenizers /app/third_party/tokenizers/ +RUN cd /app/third_party/tokenizers \ + && rustup --version \ + && cargo --version \ + && rustc --version \ + && cargo build --release \ + && cp target/release/libtokenizers_haskell.so /usr/lib/ \ + && rm -rf target \ + && find /app/.local/cargo -group root -exec chown $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID {} \; +ENV TOKENIZERS_PARALLELISM=false + +# Install Python toolchain +ENV PYTHONPATH=/app +RUN pip install --no-cache-dir pre-commit "poetry==1.1.7" +# Disable virtualenv creation to install our dependencies system-wide. +RUN poetry config virtualenvs.create false +# Config file is not readable by other users by default, which prevents +# it from being read on Drone, therefore make it readable. 
+RUN chmod go+r $XDG_CONFIG_HOME/pypoetry/config.toml +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID pyproject.toml poetry.lock /app/ +RUN poetry install --extras "deepspeed" \ + && pip install /app/folly-0.0.1-cp37-cp37m-linux_x86_64.whl \ + && pip install /app/thrift-0.0.1-cp37-cp37m-linux_x86_64.whl \ + && rm -rf $XDG_CACHE_HOME/pip \ + && rm -rf $XDG_CACHE_HOME/pypoetry \ + && chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID $XDG_CONFIG_HOME/pypoetry + +# Unfortunately, nltk doesn't look in XDG_DATA_HOME, so therefore /usr/local/share +RUN python -m nltk.downloader -d /usr/local/share/nltk_data punkt stopwords + +# ------------------------ +# Target: train +# ------------------------ +FROM dev as train + +ARG TOOLKIT_USER_ID=13011 +ARG TOOLKIT_GROUP_ID=13011 + +# Misc environment variables +ENV HF_HOME=/transformers_cache + +# Copy Seq-to-seq code +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./seq2seq /app/seq2seq/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./tests /app/tests/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./third_party/spider /app/third_party/spider/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./third_party/test_suite /app/third_party/test_suite/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./configs /app/configs/ + +# ------------------------ +# Target: eval +# ------------------------ +FROM dev as eval + +ARG TOOLKIT_USER_ID=13011 +ARG TOOLKIT_GROUP_ID=13011 + +# Add thrift file +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID picard.thrift /app/ + +# Build Cython code +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID gen-cpp2 /app/gen-cpp2/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID gen-py3 /app/gen-py3/ +RUN thrift1 --gen mstch_cpp2 picard.thrift \ + && thrift1 --gen mstch_py3 picard.thrift \ + && cd gen-py3 && python setup.py build_ext --inplace \ + && chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/gen-py3 /app/gen-cpp2 +ENV PYTHONPATH=$PYTHONPATH:/app/gen-py3 \ + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/app/gen-py3/picard + +# Build and install Picard +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID cabal.project fb-util-cabal.patch /app/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID gen-hs /app/gen-hs/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID picard /app/picard/ +RUN cabal update \ + && cd third_party/hsthrift \ + && make THRIFT_COMPILE=thrift-compiler thrift-cpp thrift-hs \ + && cd ../.. \ + && thrift-compiler --hs --use-hash-map --use-hash-set --gen-prefix gen-hs -o . 
picard.thrift \ + && patch -p 1 -d third_party/hsthrift < ./fb-util-cabal.patch \ + && cabal install --overwrite-policy=always --install-method=copy exe:picard \ + && chown $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/.cabal/bin/picard \ + && cabal clean \ + && rm -rf /app/third_party/hsthrift/compiler/tests \ + && rm -rf /app/.cabal/packages/* \ + && rm -rf /app/.cabal/logs/* \ + && chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/picard/ \ + && chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/gen-hs/ \ + && find /app/.cabal/store/ghc-8.10.*/ -maxdepth 2 -type d -group root -exec chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID {} \; \ + && find /app/.cabal/store/ghc-8.10.*/ -maxdepth 2 -type f -group root -exec chown $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID {} \; + +# Misc environment variables +ENV HF_HOME=/transformers_cache + +# Copy Seq-to-seq code +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./seq2seq /app/seq2seq/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./tests /app/tests/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./third_party/spider /app/third_party/spider/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./third_party/test_suite /app/third_party/test_suite/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./configs /app/configs/ + +# Test Picard +RUN python /app/tests/test_picard_client.py \ + && rm -rf /app/seq2seq/__pycache__ \ + && rm -rf /app/gen-py3/picard/__pycache__ diff --git a/Dockerfile.eval b/Dockerfile.eval new file mode 100644 index 00000000..e5712aa5 --- /dev/null +++ b/Dockerfile.eval @@ -0,0 +1,17 @@ +ARG BASE_IMAGE +FROM $BASE_IMAGE as dev + +ARG TOOLKIT_USER_ID=13011 +ARG TOOLKIT_GROUP_ID=13011 + +RUN pip install poetry && poetry config virtualenvs.create false +RUN poetry update +RUN pip install transformers datasets pynvml deepspeed tenacity rapidfuzz==2.0.5 nltk==3.7 \ + sqlparse==0.4.2 pyarrow==7.0.0 loguru accelerate + +ENV XDG_DATA_HOME=/app/.local/share \ + XDG_CACHE_HOME=/app/.cache \ + XDG_BIN_HOME=/app/.local/bin \ + XDG_CONFIG_HOME=/app/.config + +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./seq2seq /app/seq2seq/ \ No newline at end of file diff --git a/Dockerfile.train b/Dockerfile.train new file mode 100644 index 00000000..1c645c37 --- /dev/null +++ b/Dockerfile.train @@ -0,0 +1,59 @@ +# FROM tscholak/text-to-sql-train:6a252386bed6d4233f0f13f4562d8ae8608e7445 +# with this Dockerfile, train started fine, but unkown error due to deepspeed cpu offload + +ARG BASE_IMAGE + +# ------------------------ +# Target: dev +# ------------------------ +FROM $BASE_IMAGE as dev + +ARG TOOLKIT_USER_ID=13011 +ARG TOOLKIT_GROUP_ID=13011 + +RUN \ + # Required to save git hashes + apt-get install -y -q git curl unzip make gettext + +ENV XDG_DATA_HOME=/app/.local/share \ + XDG_CACHE_HOME=/app/.cache \ + XDG_BIN_HOME=/app/.local/bin \ + XDG_CONFIG_HOME=/app/.config +RUN mkdir -p $XDG_DATA_HOME \ + && mkdir -p $XDG_CACHE_HOME \ + && mkdir -p $XDG_BIN_HOME \ + && mkdir -p $XDG_CONFIG_HOME \ + && chown -R $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app + +WORKDIR /app + +# Misc environment variables +ENV HF_HOME=/transformers_cache + +# datasets==1.18.4 +# copy poetry files +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID pyproject.toml poetry.lock /app/ + +RUN pip install poetry && poetry config virtualenvs.create false +RUN poetry update +RUN pip install transformers datasets pynvml deepspeed tenacity rapidfuzz==2.0.5 nltk==3.7 \ + sqlparse==0.4.2 pyarrow==7.0.0 loguru accelerate +# RUN poetry install + + +RUN git clone https://github.com/microsoft/DeepSpeed/ && \ + cd 
DeepSpeed && git checkout v0.8.1 && \ + rm -rf build && \ + TORCH_CUDA_ARCH_LIST="8.6" DS_BUILD_CPU_ADAM=1 DS_BUILD_UTILS=1 pip install . \ + --global-option="build_ext" --global-option="-j8" --no-cache -v \ + --disable-pip-version-check + +# Copy Seq-to-seq code +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./seq2seq /app/seq2seq/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./tests /app/tests/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./third_party/spider /app/third_party/spider/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./third_party/test_suite /app/third_party/test_suite/ +COPY --chown=$TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID ./configs /app/configs/ + +# change permission for /app/.cache +RUN mkdir -p /app/.cache && chown $TOOLKIT_USER_ID:$TOOLKIT_GROUP_ID /app/.cache diff --git a/Makefile b/Makefile index 0857a93f..8188e4af 100644 --- a/Makefile +++ b/Makefile @@ -64,59 +64,65 @@ pull-dev-image: .PHONY: build-train-image build-train-image: - ssh-add - docker buildx build \ - --builder $(BUILDKIT_BUILDER) \ - --ssh default=$(SSH_AUTH_SOCK) \ - -f Dockerfile \ - --tag tscholak/$(TRAIN_IMAGE_NAME):$(GIT_HEAD_REF) \ - --tag tscholak/$(TRAIN_IMAGE_NAME):cache \ - --build-arg BASE_IMAGE=$(BASE_IMAGE) \ - --target train \ - --cache-from type=registry,ref=tscholak/$(TRAIN_IMAGE_NAME):cache \ - --cache-to type=inline \ - --push \ - git@github.com:ElementAI/picard#$(GIT_HEAD_REF) + docker build . -f Dockerfile.train -t picard --build-arg BASE_IMAGE=tscholak/text-to-sql-train:6a252386bed6d4233f0f13f4562d8ae8608e7445 + # docker build . -f Dockerfile.train -t picard --build-arg BASE_IMAGE=pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel + # docker buildx build + # --builder $(BUILDKIT_BUILDER) \ + # --ssh default=$(SSH_AUTH_SOCK) \ + # -f Dockerfile.train \ + # --tag picard \ + # --build-arg BASE_IMAGE=$(BASE_IMAGE) \ + # --target train \ + # --cache-from type=registry,ref=tscholak/$(TRAIN_IMAGE_NAME):cache \ + # --cache-to type=inline .PHONY: pull-train-image pull-train-image: - docker pull tscholak/$(TRAIN_IMAGE_NAME):$(GIT_HEAD_REF) + docker pull tscholak/text-to-sql-train:6a252386bed6d4233f0f13f4562d8ae8608e7445 + docker build . -t picard -f Dockerfile.train .PHONY: build-eval-image -build-eval-image: - ssh-add - docker buildx build \ - --builder $(BUILDKIT_BUILDER) \ - --ssh default=$(SSH_AUTH_SOCK) \ - -f Dockerfile \ - --tag tscholak/$(EVAL_IMAGE_NAME):$(GIT_HEAD_REF) \ - --tag tscholak/$(EVAL_IMAGE_NAME):cache \ - --build-arg BASE_IMAGE=$(BASE_IMAGE) \ - --target eval \ - --cache-from type=registry,ref=tscholak/$(EVAL_IMAGE_NAME):cache \ - --cache-to type=inline \ - --push \ - git@github.com:ElementAI/picard#$(GIT_HEAD_REF) +build-eval-image: + docker build . 
-f Dockerfile.eval -t picard-eval --build-arg BASE_IMAGE=tscholak/text-to-sql-eval:6a252386bed6d4233f0f13f4562d8ae8608e7445 + # ssh-add + # docker buildx build \ + # --builder $(BUILDKIT_BUILDER) \ + # --ssh default=$(SSH_AUTH_SOCK) \ + # -f Dockerfile \ + # --tag tscholak/$(EVAL_IMAGE_NAME):$(GIT_HEAD_REF) \ + # --tag tscholak/$(EVAL_IMAGE_NAME):cache \ + # --build-arg BASE_IMAGE=$(BASE_IMAGE) \ + # --target eval \ + # --cache-from type=registry,ref=tscholak/$(EVAL_IMAGE_NAME):cache \ + # --cache-to type=inline \ + # --push \ + # git@github.com:ElementAI/picard#$(GIT_HEAD_REF) .PHONY: pull-eval-image pull-eval-image: - docker pull tscholak/$(EVAL_IMAGE_NAME):$(GIT_HEAD_REF) + docker pull tscholak/text-to-sql-eval:6a252386bed6d4233f0f13f4562d8ae8608e7445 .PHONY: train -train: pull-train-image +train: build-train-image mkdir -p -m 777 train mkdir -p -m 777 transformers_cache mkdir -p -m 777 wandb docker run \ -it \ --rm \ - --user 13011:13011 \ - --mount type=bind,source=$(BASE_DIR)/train,target=/train \ - --mount type=bind,source=$(BASE_DIR)/transformers_cache,target=/transformers_cache \ - --mount type=bind,source=$(BASE_DIR)/configs,target=/app/configs \ - --mount type=bind,source=$(BASE_DIR)/wandb,target=/app/wandb \ - tscholak/$(TRAIN_IMAGE_NAME):$(GIT_HEAD_REF) \ - /bin/bash -c "python seq2seq/run_seq2seq.py configs/train.json" + --name picard \ + --gpus all \ + --ulimit memlock=-1:-1 \ + --ipc host \ + -v $(BASE_DIR)/train_output:/train_output \ + -v $(BASE_DIR)/transformers_cache:/transformers_cache \ + -v $(BASE_DIR)/configs:/app/configs \ + -v $(BASE_DIR)/wandb:/app/wandb \ + -v $(BASE_DIR)/data:/app/data \ + --env WANDB_API_KEY \ + -e TRANSFORMERS_CACHE=/transformers_cache \ + picard \ + /bin/bash -c "deepspeed --num_gpus=4 seq2seq/run_seq2seq.py configs/train.json" .PHONY: train_cosql train_cosql: pull-train-image @@ -131,23 +137,24 @@ train_cosql: pull-train-image --mount type=bind,source=$(BASE_DIR)/transformers_cache,target=/transformers_cache \ --mount type=bind,source=$(BASE_DIR)/configs,target=/app/configs \ --mount type=bind,source=$(BASE_DIR)/wandb,target=/app/wandb \ - tscholak/$(TRAIN_IMAGE_NAME):$(GIT_HEAD_REF) \ + picard \ /bin/bash -c "python seq2seq/run_seq2seq.py configs/train_cosql.json" .PHONY: eval -eval: pull-eval-image +eval: build-eval-image mkdir -p -m 777 eval mkdir -p -m 777 transformers_cache mkdir -p -m 777 wandb docker run \ -it \ --rm \ - --user 13011:13011 \ - --mount type=bind,source=$(BASE_DIR)/eval,target=/eval \ - --mount type=bind,source=$(BASE_DIR)/transformers_cache,target=/transformers_cache \ - --mount type=bind,source=$(BASE_DIR)/configs,target=/app/configs \ - --mount type=bind,source=$(BASE_DIR)/wandb,target=/app/wandb \ - tscholak/$(EVAL_IMAGE_NAME):$(GIT_HEAD_REF) \ + --gpus all \ + -v $(BASE_DIR)/eval_output:/eval_output \ + -v $(BASE_DIR)/transformers_cache:/transformers_cache \ + -v $(BASE_DIR)/configs:/app/configs \ + -v /xdata/train_output:/train_output \ + -e TRANSFORMERS_CACHE=/transformers_cache \ + picard-eval \ /bin/bash -c "python seq2seq/run_seq2seq.py configs/eval.json" .PHONY: eval_cosql @@ -170,6 +177,7 @@ eval_cosql: pull-eval-image serve: pull-eval-image mkdir -p -m 777 database mkdir -p -m 777 transformers_cache + docker build . 
-t picard -f Dockerfile.eval docker run \ -it \ --rm \ @@ -178,7 +186,8 @@ serve: pull-eval-image --mount type=bind,source=$(BASE_DIR)/database,target=/database \ --mount type=bind,source=$(BASE_DIR)/transformers_cache,target=/transformers_cache \ --mount type=bind,source=$(BASE_DIR)/configs,target=/app/configs \ - tscholak/$(EVAL_IMAGE_NAME):$(GIT_HEAD_REF) \ + --name picard \ + picard \ /bin/bash -c "python seq2seq/serve_seq2seq.py configs/serve.json" .PHONY: prediction_output diff --git a/README.md b/README.md index ccf9504a..df88f9b8 100644 --- a/README.md +++ b/README.md @@ -376,3 +376,17 @@ There are three docker images that can be used to run the code: * **[tscholak/text-to-sql-eval](https://hub.docker.com/repository/docker/tscholak/text-to-sql-eval):** Training/evaluation image with all dependencies. Use this for evaluating a fine-tuned model with Picard. This image can also be used for training if you want to run evaluation during training with Picard. Pull it with `make pull-eval-image` from the docker hub. Rebuild the image with `make build-eval-image`. All images are tagged with the current commit hash. The images are built with the buildx tool which is available in the latest docker-ce. Use `make init-buildkit` to initialize the buildx tool on your machine. You can then use `make build-dev-image`, `make build-train-image`, etc. to rebuild the images. Local changes to the code will not be reflected in the docker images unless they are committed to git. + +### Using DeepSpeed +Training on a 24 GB GPU was not possible even with a batch size of 1, so we use DeepSpeed. DeepSpeed failed silently inside Docker, so run it directly on the host. + +```shell +# install deepspeed +export PATH="/usr/local/cuda-11.7/bin:$PATH" +export LD_LIBRARY_PATH="/usr/local/cuda-11.7/lib64:$LD_LIBRARY_PATH" +DS_BUILD_CPU_ADAM=1 DS_BUILD_AIO=0 pip install deepspeed --global-option="build_ext" --global-option="-j8" + +# start training +deepspeed seq2seq/run_seq2seq.py configs/train.json +``` + diff --git a/configs/ds_config_zero2.json b/configs/ds_config_zero2.json new file mode 100644 index 00000000..02c1551b --- /dev/null +++ b/configs/ds_config_zero2.json @@ -0,0 +1,54 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + "bf16": { + "enabled": "auto" + }, + + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/configs/eval.json b/configs/eval.json index cce1861c..6434af97 100644 --- a/configs/eval.json +++ b/configs/eval.json @@ -1,6 +1,6 @@ { "run_name": "t5+picard-spider-eval", - "model_name_or_path": "tscholak/cxmefzzi", + "model_name_or_path": "tscholak/3vnuv1vf", "dataset": "spider", "source_prefix": "", "schema_serialization_type":
"peteshaw", @@ -9,19 +9,19 @@ "schema_serialization_with_db_content": true, "normalize_query": true, "target_with_db_id": true, - "output_dir": "/eval", + "output_dir": "/eval_output", "cache_dir": "/transformers_cache", "do_train": false, "do_eval": true, "fp16": false, - "per_device_eval_batch_size": 1, + "per_device_eval_batch_size": 4, "seed": 1, - "report_to": ["wandb"], + "report_to": [], "predict_with_generate": true, "num_beams": 4, "num_beam_groups": 1, "diversity_penalty": 0.0, - "max_val_samples": 1034, + "max_val_samples": 1600, "use_picard": true, "launch_picard": true, "picard_mode": "parse_with_guards", diff --git a/configs/serve.json b/configs/serve.json index 3a929227..0fd31c81 100644 --- a/configs/serve.json +++ b/configs/serve.json @@ -1,20 +1,21 @@ { - "model_path": "tscholak/3vnuv1vf", + "model_path": "/data/checkpoint-90", "source_prefix": "", "schema_serialization_type": "peteshaw", "schema_serialization_randomized": false, "schema_serialization_with_db_id": true, "schema_serialization_with_db_content": true, + "include_foreign_keys_in_schema": true, "normalize_query": true, "target_with_db_id": true, - "db_path": "/database", - "cache_dir": "/transformers_cache", + "db_path": "/home/ubuntu/trans_cache/downloads/extracted/d712e0d61bf3021b084b5268e3189f7f8882e4938131c9c749b9e008c833cef3/spider/database/", + "cache_dir": "~/trans_cache", "num_beams": 4, "num_return_sequences": 1, - "use_picard": true, - "launch_picard": true, + "use_picard": false, + "launch_picard": false, "picard_mode": "parse_with_guards", "picard_schedule": "incremental", "picard_max_tokens_to_check": 2, - "device": 0 + "device": -1 } diff --git a/configs/train.json b/configs/train.json index 5e2891e8..d9f3b8e2 100644 --- a/configs/train.json +++ b/configs/train.json @@ -1,6 +1,6 @@ { - "run_name": "t5-spider", - "model_name_or_path": "t5-3b", + "run_name": "picard-001-fk", + "model_name_or_path": "tscholak/cxmefzzi", "dataset": "spider", "source_prefix": "", "schema_serialization_type": "peteshaw", @@ -9,36 +9,39 @@ "schema_serialization_with_db_content": true, "normalize_query": true, "target_with_db_id": true, - "output_dir": "/train", + "output_dir": "/train_output", "cache_dir": "/transformers_cache", "do_train": true, - "do_eval": true, + "do_eval": false, "fp16": false, - "num_train_epochs": 3072, - "per_device_train_batch_size": 5, - "per_device_eval_batch_size": 5, - "gradient_accumulation_steps": 410, + "num_train_epochs": 32, + "per_device_train_batch_size": 4, + "per_device_eval_batch_size": 4, + "gradient_accumulation_steps": 8, + "gradient_checkpointing": true, "label_smoothing_factor": 0.0, "learning_rate": 1e-4, - "adafactor": true, - "adam_eps": 1e-6, "lr_scheduler_type": "constant", "warmup_ratio": 0.0, "warmup_steps": 0, "seed": 1, - "report_to": ["wandb"], + "report_to": [], "logging_strategy": "steps", "logging_first_step": true, "logging_steps": 4, - "load_best_model_at_end": true, + "load_best_model_at_end": false, "metric_for_best_model": "exact_match", "greater_is_better": true, "save_total_limit": 128, + "save_strategy": "steps", "save_steps": 64, - "evaluation_strategy": "steps", - "eval_steps": 64, + "evaluation_strategy": "no", + "eval_steps": 1, "predict_with_generate": true, "num_beams": 1, "num_beam_groups": 1, - "use_picard": false + "use_picard": true, + "overwrite_output_dir": true, + "deepspeed": "configs/ds_config_zero2.json", + "overwrite_cache": false } diff --git a/nlp_picard.postman_collection.json b/nlp_picard.postman_collection.json new file mode 100644 
index 00000000..2ccb4da9 --- /dev/null +++ b/nlp_picard.postman_collection.json @@ -0,0 +1,156 @@ +{ + "info": { + "_postman_id": "c47177e2-5637-48ed-9671-27acbb36fb85", + "name": "nlp_picard", + "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json", + "_exporter_id": "17757684" + }, + "item": [ + { + "name": "get query", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "http://54.90.26.47:8000/ask/?db_id=well&question=what is the average length of the well bore", + "protocol": "http", + "host": [ + "54", + "90", + "26", + "47" + ], + "port": "8000", + "path": [ + "ask", + "" + ], + "query": [ + { + "key": "db_id", + "value": "well" + }, + { + "key": "question", + "value": "what is the average length of the well bore" + } + ] + } + }, + "response": [] + }, + { + "name": "get databases", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "http://54.90.26.47:8000/database/", + "protocol": "http", + "host": [ + "54", + "90", + "26", + "47" + ], + "port": "8000", + "path": [ + "database", + "" + ] + } + }, + "response": [] + }, + { + "name": "get schema for database", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "http://54.90.26.47:8000/schema/well/", + "protocol": "http", + "host": [ + "54", + "90", + "26", + "47" + ], + "port": "8000", + "path": [ + "schema", + "well", + "" + ] + } + }, + "response": [] + }, + { + "name": "create schema", + "request": { + "method": "POST", + "header": [], + "body": { + "mode": "raw", + "raw": "[\n \"CREATE TABLE well(id, name)\",\n \"CREATE TABLE wellbore(id, well_id, bore_length)\"\n]", + "options": { + "raw": { + "language": "json" + } + } + }, + "url": { + "raw": "http://54.90.26.47:8000/schema/well2/", + "protocol": "http", + "host": [ + "54", + "90", + "26", + "47" + ], + "port": "8000", + "path": [ + "schema", + "well2", + "" + ] + } + }, + "response": [] + }, + { + "name": "update schema", + "request": { + "method": "PATCH", + "header": [], + "body": { + "mode": "raw", + "raw": "[\n \"CREATE TABLE well_2(id, name)\",\n \"CREATE TABLE wellbore_2(id, well_id, bore_length)\"\n]", + "options": { + "raw": { + "language": "json" + } + } + }, + "url": { + "raw": "http://54.90.26.47:8000/schema/well2/", + "protocol": "http", + "host": [ + "54", + "90", + "26", + "47" + ], + "port": "8000", + "path": [ + "schema", + "well2", + "" + ] + } + }, + "response": [] + } + ] +} \ No newline at end of file diff --git a/seq2seq/prediction_output.py b/seq2seq/prediction_output.py index 6ae07947..2df42d9b 100644 --- a/seq2seq/prediction_output.py +++ b/seq2seq/prediction_output.py @@ -1,5 +1,6 @@ # Set up logging import sys +sys.path.append('.') import logging logging.basicConfig( @@ -122,6 +123,7 @@ def get_pipeline_kwargs( "schema_serialization_type": data_training_args.schema_serialization_type, "schema_serialization_with_db_id": data_training_args.schema_serialization_with_db_id, "schema_serialization_with_db_content": data_training_args.schema_serialization_with_db_content, + "include_foreign_keys": data_training_args.include_foreign_keys_in_schema, "device": prediction_output_args.device, } diff --git a/seq2seq/run_seq2seq.py b/seq2seq/run_seq2seq.py index d55186d3..e667fb49 100644 --- a/seq2seq/run_seq2seq.py +++ b/seq2seq/run_seq2seq.py @@ -1,5 +1,6 @@ # Set up logging import sys +sys.path.append('.') import logging logging.basicConfig( @@ -30,10 +31,11 @@ from seq2seq.utils.dataset_loader import load_dataset from seq2seq.utils.spider import SpiderTrainer from 
seq2seq.utils.cosql import CoSQLTrainer +from seq2seq.utils.trainer import print_gpu_utilization def main() -> None: - # See all possible arguments by passing the --help flag to this script. + # See all possible arguments by passing the --help flag to this script. parser = HfArgumentParser( (PicardArguments, ModelArguments, DataArguments, DataTrainingArguments, Seq2SeqTrainingArguments) ) @@ -127,6 +129,7 @@ def main() -> None: ) # Initialize tokenizer + logger.warning('loading tokenizer...') tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, @@ -140,6 +143,7 @@ def main() -> None: tokenizer.add_tokens([AddedToken(" <="), AddedToken(" <")]) # Load dataset + logger.warning('loading dataset...') metric, dataset_splits = load_dataset( data_args=data_args, model_args=model_args, @@ -147,6 +151,7 @@ def main() -> None: training_args=training_args, tokenizer=tokenizer, ) + logger.warning('loading dataset complete') # Initialize Picard if necessary with PicardLauncher() if picard_args.launch_picard and training_args.local_rank <= 0 else nullcontext(None): @@ -159,6 +164,7 @@ def main() -> None: model_cls_wrapper = lambda model_cls: model_cls # Initialize model + logger.warning('loading model...') model = model_cls_wrapper(AutoModelForSeq2SeqLM).from_pretrained( model_args.model_name_or_path, from_tf=bool(".ckpt" in model_args.model_name_or_path), @@ -167,6 +173,9 @@ def main() -> None: revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) + print_gpu_utilization() + + if isinstance(model, T5ForConditionalGeneration): model.resize_token_embeddings(len(tokenizer)) @@ -204,7 +213,7 @@ def main() -> None: # Training if training_args.do_train: - logger.info("*** Train ***") + logger.warning("*** Train ***") checkpoint = None @@ -212,7 +221,7 @@ def main() -> None: checkpoint = training_args.resume_from_checkpoint elif last_checkpoint is not None: checkpoint = last_checkpoint - + train_result = trainer.train(resume_from_checkpoint=checkpoint) trainer.save_model() # Saves the tokenizer too for easy upload diff --git a/seq2seq/serve_seq2seq.py b/seq2seq/serve_seq2seq.py index 9cd5875a..58f7e1fc 100644 --- a/seq2seq/serve_seq2seq.py +++ b/seq2seq/serve_seq2seq.py @@ -1,5 +1,6 @@ # Set up logging import sys +sys.path.append('.') import logging logging.basicConfig( @@ -9,6 +10,7 @@ level=logging.WARNING, ) logger = logging.getLogger(__name__) +from loguru import logger from typing import Optional, Dict from dataclasses import dataclass, field @@ -20,9 +22,16 @@ from fastapi import FastAPI, HTTPException from uvicorn import run from sqlite3 import Connection, connect, OperationalError -from seq2seq.utils.pipeline import Text2SQLGenerationPipeline, Text2SQLInput, get_schema +from seq2seq.utils.pipeline import (Text2SQLGenerationPipeline, Text2SQLGenPipelineWithSchema, + Text2SQLInput, QuestionWithSchemaInput, get_schema, get_schema_for_display, + get_db_file_path) from seq2seq.utils.picard_model_wrapper import PicardArguments, PicardLauncher, with_picard +from seq2seq.utils.dataset import serialize_schema from seq2seq.utils.dataset import DataTrainingArguments +from seq2seq.utils.spider import spider_get_input +import sqlite3 +from pathlib import Path +from typing import List @dataclass @@ -67,6 +76,7 @@ def main(): picard_args, backend_args, data_training_args = parser.parse_args_into_dataclasses() # Initialize config + logger.info(f'loading model...') 
config = AutoConfig.from_pretrained( backend_args.model_path, cache_dir=backend_args.cache_dir, @@ -113,9 +123,21 @@ def main(): device=backend_args.device, ) + pipe_with_schema = Text2SQLGenPipelineWithSchema( + model = model, + tokenizer = tokenizer, + db_path = backend_args.db_path, + normalize_query = data_training_args.normalize_query, + device = backend_args.device) + + # Initialize REST API app = FastAPI() + class Query(BaseModel): + question: str + db_schema: str + class AskResponse(BaseModel): query: str execution_results: list @@ -128,7 +150,7 @@ def response(query: str, conn: Connection) -> AskResponse: status_code=500, detail=f'while executing "{query}", the following error occurred: {e.args[0]}' ) - @app.get("/ask/{db_id}/{question}") + @app.get("/ask/") def ask(db_id: str, question: str): try: outputs = pipe( @@ -143,9 +165,107 @@ def ask(db_id: str, question: str): finally: conn.close() + + @app.post("/ask-with-schema/") + def ask_with_schema(query: Query): + try: + outputs = pipe_with_schema( + inputs = QuestionWithSchemaInput(utterance=query.question, schema=query.db_schema), + num_return_sequences=data_training_args.num_return_sequences + ) + except OperationalError as e: + raise HTTPException(status_code=404, detail=e.args[0]) + + return [output["generated_text"] for output in outputs] + + + @app.get("/database/") + def get_database_list(): + db_dir = Path(backend_args.db_path) + + print(f'db_path - {db_dir}') + db_files = db_dir.rglob("*.sqlite") + return [db_file.stem for db_file in db_files if db_file.stem == db_file.parent.stem] + + @app.get("/schema/{db_id}") + def get_schema_for_database(db_id): + return get_schema(backend_args.db_path, db_id) + + @app.get("/serialized-schema/{db_id}/") + def get_serialized_schema(db_id, schema_serialization_type = "peteshaw", + schema_serialization_randomized = False, + schema_serialization_with_db_id = True, + schema_serialization_with_db_content = False + ): + schema = pipe_with_schema.get_schema_from_cache(db_id) + serialized_schema = serialize_schema(question='question', + db_path = backend_args.db_path, + db_id = db_id, + db_column_names = schema['db_column_names'], + db_table_names = schema['db_table_names'], + schema_serialization_type = schema_serialization_type, + schema_serialization_randomized = schema_serialization_randomized, + schema_serialization_with_db_id = schema_serialization_with_db_id, + schema_serialization_with_db_content = schema_serialization_with_db_content, + include_foreign_keys=data_training_args.include_foreign_keys_in_schema, + foreign_keys=schema['db_foreign_keys'] + ) + return spider_get_input('question', serialized_schema, prefix='') + + + @app.post("/schema/{db_id}") + def create_schema(db_id, queries: List[str]): + db_file_path = Path(get_db_file_path(backend_args.db_path, db_id)) + + if db_file_path.exists(): + raise HTTPException(status_code=409, detail="database already exists") + + # create parent directory if it doesn't exist + db_file_path.parent.mkdir(parents=True, exist_ok=True) + + print(f'creating database {db_file_path.as_posix()}...') + + con = sqlite3.connect(db_file_path.as_posix()) + cur = con.cursor() + try: + for query in queries: + cur.execute(query) + con.commit() + except OperationalError as e: + raise HTTPException(status_code=400, detail=e.args[0]) + finally: + con.close() + + return get_schema(backend_args.db_path, db_id) + + @app.patch("/schema/{db_id}") + def update_schema(db_id, queries: List[str]): + db_file_path = Path(get_db_file_path(backend_args.db_path, db_id)) + 
+ if not db_file_path.exists(): + raise HTTPException(status_code=404, detail="database not found") + + print(f'updating database {db_file_path.as_posix()}...') + + con = sqlite3.connect(db_file_path.as_posix()) + cur = con.cursor() + try: + for query in queries: + cur.execute(query) + con.commit() + except OperationalError as e: + raise HTTPException(status_code=400, detail=e.args[0]) + finally: + con.close() + + return get_schema(backend_args.db_path, db_id) + + + # Run app run(app=app, host=backend_args.host, port=backend_args.port) if __name__ == "__main__": + print('serving....') main() diff --git a/seq2seq/utils/cosql.py b/seq2seq/utils/cosql.py index 28121aa7..d7baea49 100644 --- a/seq2seq/utils/cosql.py +++ b/seq2seq/utils/cosql.py @@ -47,6 +47,7 @@ def cosql_add_serialized_schema( schema_serialization_with_db_id=data_training_args.schema_serialization_with_db_id, schema_serialization_with_db_content=data_training_args.schema_serialization_with_db_content, normalize_query=data_training_args.normalize_query, + foreign_keys=ex["db_foreign_keys"], ) return {"serialized_schema": serialized_schema} diff --git a/seq2seq/utils/dataset.py b/seq2seq/utils/dataset.py index d92bf0d4..d84461dc 100644 --- a/seq2seq/utils/dataset.py +++ b/seq2seq/utils/dataset.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict, Callable +from typing import Optional, List, Dict, Callable, Tuple from dataclasses import dataclass, field from datasets.dataset_dict import DatasetDict from datasets.arrow_dataset import Dataset @@ -6,6 +6,11 @@ from seq2seq.utils.bridge_content_encoder import get_database_matches import re import random +import os +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) @dataclass @@ -19,7 +24,7 @@ class DataTrainingArguments: metadata={"help": "Overwrite the cached training and evaluation sets"}, ) preprocessing_num_workers: Optional[int] = field( - default=None, + default=int(os.cpu_count() * 0.75) if os.cpu_count() is not None else 4, # set to half of the number of CPUs metadata={"help": "The number of processes to use for the preprocessing."}, ) max_source_length: Optional[int] = field( @@ -125,7 +130,10 @@ class DataTrainingArguments: default=True, metadata={"help": "Whether or not to add the database id to the target. Needed for Picard."}, ) - + include_foreign_keys_in_schema: bool = field( + default=True, + metadata={"help": "Whether or not to include foreign keys in the schema."}, + ) def __post_init__(self): if self.val_max_target_length is None: self.val_max_target_length = self.max_target_length @@ -354,7 +362,10 @@ def serialize_schema( schema_serialization_with_db_id: bool = True, schema_serialization_with_db_content: bool = False, normalize_query: bool = True, + include_foreign_keys: bool = False, + foreign_keys: Optional[List[Tuple[str, str]]] = None ) -> str: + # logger.warning(f'foreign keys for {db_id} is {foreign_keys}. tb_tables - {db_table_names}') if schema_serialization_type == "verbose": db_id_str = "Database: {db_id}. " table_sep = ". 
" @@ -375,8 +386,21 @@ def serialize_schema( else: raise NotImplementedError - def get_column_str(table_name: str, column_name: str) -> str: + def get_column_str(table_id: int, table_name: str, column_name: str, include_foreign_keys:bool) -> str: column_name_str = column_name.lower() if normalize_query else column_name + if include_foreign_keys: + # get location of fk in foreign_keys list + column_id = db_column_names['column_name'].index(column_name) if column_name in db_column_names['column_name'] else None + fk_idx = foreign_keys['column_id'].index(column_id) if column_id in foreign_keys['column_id'] else None + if fk_idx is not None: + other_column_id = foreign_keys['other_column_id'][fk_idx] + other_table_id = db_column_names['table_id'][other_column_id] + other_table_name = db_table_names[other_table_id] + other_column_name = db_column_names['column_name'][other_column_id] + fk_str = f'__fk__{other_table_name}.{other_column_name}' + column_name_str = column_name_str + fk_str + + if schema_serialization_with_db_content: matches = get_database_matches( question=question, @@ -396,7 +420,8 @@ def get_column_str(table_name: str, column_name: str) -> str: table=table_name.lower() if normalize_query else table_name, columns=column_sep.join( map( - lambda y: get_column_str(table_name=table_name, column_name=y[1]), + lambda y: get_column_str(table_id=table_id, table_name=table_name, column_name=y[1], + include_foreign_keys=include_foreign_keys), filter( lambda y: y[0] == table_id, zip( @@ -415,4 +440,5 @@ def get_column_str(table_name: str, column_name: str) -> str: serialized_schema = db_id_str.format(db_id=db_id) + table_sep.join(tables) else: serialized_schema = table_sep.join(tables) + # logger.warning(f'serialized schema for {db_id} is {serialized_schema}.') return serialized_schema diff --git a/seq2seq/utils/picard_model_wrapper.py b/seq2seq/utils/picard_model_wrapper.py index 1d771573..59befcb1 100644 --- a/seq2seq/utils/picard_model_wrapper.py +++ b/seq2seq/utils/picard_model_wrapper.py @@ -11,8 +11,10 @@ import torch from transformers import LogitsProcessorList from transformers.configuration_utils import PretrainedConfig -from transformers.generation_utils import GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput -from transformers.generation_logits_process import LogitsProcessor +# from transformers.generation_utils import GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput +from transformers.generation import GreedySearchEncoderDecoderOutput, SampleEncoderDecoderOutput, BeamSearchEncoderDecoderOutput, BeamSampleEncoderDecoderOutput +# from transformers.generation_logits_process import LogitsProcessor +from transformers import LogitsProcessor from transformers.file_utils import copy_func from transformers.models.auto.auto_factory import _get_model_class from transformers.models.auto.configuration_auto import AutoConfig @@ -158,7 +160,8 @@ def _generate( logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), eos_token_id: Optional[int] = None, **kwargs, - ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: + ) -> Union[GreedySearchEncoderDecoderOutput, SampleEncoderDecoderOutput, BeamSearchEncoderDecoderOutput, + BeamSampleEncoderDecoderOutput, torch.LongTensor]: eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id logits_processor.append( diff --git a/seq2seq/utils/pipeline.py b/seq2seq/utils/pipeline.py index 6f5c51cb..d10bea79 100644 --- 
a/seq2seq/utils/pipeline.py +++ b/seq2seq/utils/pipeline.py @@ -8,12 +8,20 @@ from seq2seq.utils.spider import spider_get_input from seq2seq.utils.cosql import cosql_get_input +import logging +logger = logging.getLogger(__name__) + @dataclass class Text2SQLInput(object): utterance: str db_id: str +@dataclass +class QuestionWithSchemaInput(object): + utterance: str + schema: str + class Text2SQLGenerationPipeline(Text2TextGenerationPipeline): """ @@ -35,7 +43,7 @@ class Text2SQLGenerationPipeline(Text2TextGenerationPipeline): """ def __init__(self, *args, **kwargs): - self.db_path: str = kwargs.pop("db_path") + self.db_path: str = kwargs.pop("db_path", None) self.prefix: Optional[str] = kwargs.pop("prefix", None) self.normalize_query: bool = kwargs.pop("normalize_query", True) self.schema_serialization_type: str = kwargs.pop("schema_serialization_type", "peteshaw") @@ -43,6 +51,8 @@ def __init__(self, *args, **kwargs): self.schema_serialization_with_db_id: bool = kwargs.pop("schema_serialization_with_db_id", True) self.schema_serialization_with_db_content: bool = kwargs.pop("schema_serialization_with_db_content", True) self.schema_cache: Dict[str, dict] = dict() + self.include_foreign_keys = kwargs.pop("include_foreign_keys_in_schema", True) + logger.warning(f'include_foreign_keys 2 is {self.include_foreign_keys}') super().__init__(*args, **kwargs) def __call__(self, inputs: Union[Text2SQLInput, List[Text2SQLInput]], *args, **kwargs): @@ -74,6 +84,7 @@ def __call__(self, inputs: Union[Text2SQLInput, List[Text2SQLInput]], *args, **k -- The token ids of the generated SQL. """ result = super().__call__(inputs, *args, **kwargs) + print(f'with db output is :{result}') if ( isinstance(inputs, list) and all(isinstance(el, Text2SQLInput) for el in inputs) @@ -116,11 +127,15 @@ def _parse_and_tokenize( del encodings["token_type_ids"] return encodings + + def get_schema_from_cache(self, db_id): + if db_id not in self.schema_cache: + self.schema_cache[db_id] = get_schema(db_path=self.db_path, db_id=db_id) + return self.schema_cache[db_id] + def _pre_process(self, input: Text2SQLInput) -> str: prefix = self.prefix if self.prefix is not None else "" - if input.db_id not in self.schema_cache: - self.schema_cache[input.db_id] = get_schema(db_path=self.db_path, db_id=input.db_id) - schema = self.schema_cache[input.db_id] + schema = self.get_schema_from_cache(input.db_id) if hasattr(self.model, "add_schema"): self.model.add_schema(db_id=input.db_id, db_info=schema) serialized_schema = serialize_schema( @@ -134,8 +149,12 @@ def _pre_process(self, input: Text2SQLInput) -> str: schema_serialization_with_db_id=self.schema_serialization_with_db_id, schema_serialization_with_db_content=self.schema_serialization_with_db_content, normalize_query=self.normalize_query, + include_foreign_keys=self.include_foreign_keys, + foreign_keys=schema["db_foreign_keys"], ) - return spider_get_input(question=input.utterance, serialized_schema=serialized_schema, prefix=prefix) + spider_input = spider_get_input(question=input.utterance, serialized_schema=serialized_schema, prefix=prefix) + print(f'spider input is:{spider_input}') + return spider_input def postprocess(self, model_outputs: dict, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False): records = [] @@ -155,6 +174,92 @@ def postprocess(self, model_outputs: dict, return_type=ReturnType.TEXT, clean_up records.append(record) return records +class Text2SQLGenPipelineWithSchema(Text2SQLGenerationPipeline): + """ + Pipeline for text-to-SQL generation using seq2seq 
models. Here Database schema is passed along with query + + model = AutoModelForSeq2SeqLM.from_pretrained(...) + tokenizer = AutoTokenizer.from_pretrained(...) + db_path = ... path to "concert_singer" parent folder + text2sql_generator = Text2SQLGenerationPipeline( + model=model, + tokenizer=tokenizer, + ) + text2sql_generator(inputs=Text2SQLInput(utterance="How many singers do we have?", db_id="concert_singer")) + """ + def __init__(self, *args, **kwargs): + self.normalize_query: bool = kwargs.pop("normalize_query", True) + super().__init__(*args, **kwargs) + + def _pre_process(self, input: QuestionWithSchemaInput) -> str: + # prefix = self.prefix if self.prefix is not None else "" + spider_input = spider_get_input(question=input.utterance, serialized_schema=input.schema, prefix='') + print(f'spider input is :{spider_input}') + return spider_input + + def __call__(self, inputs: Union[QuestionWithSchemaInput, List[QuestionWithSchemaInput]], *args, **kwargs): + r""" + Generate the output SQL expression(s) using text(s) given as inputs. + + Args: + inputs (:obj:`Text2SQLInput` or :obj:`List[Text2SQLInput]`): + Input text(s) for the encoder. + return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to include the tensors of predictions (as token indices) in the outputs. + return_text (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to include the decoded texts in the outputs. + clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to clean up the potential extra spaces in the text output. + truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`): + The truncation strategy for the tokenization within the pipeline. + :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable + to truncate the input to fit the model's max_length instead of throwing an error down the line. + generate_kwargs: + Additional keyword arguments to pass along to the generate method of the model (see the generate method + corresponding to your framework `here <./model.html#generative-models>`__). + + Return: + A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys: + + - **generated_sql** (:obj:`str`, present when ``return_text=True``) -- The generated SQL. + - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``) + -- The token ids of the generated SQL. + """ + result = super().__call__(inputs, *args, **kwargs) + print(f'with schema output is :{result}') + if ( + isinstance(inputs, list) + and all(isinstance(el, QuestionWithSchemaInput) for el in inputs) + and all(len(res) == 1 for res in result) + ): + return [res[0] for res in result] + return result + + # no changes from parent class other than input type + def _parse_and_tokenize( + self, + inputs: Union[QuestionWithSchemaInput, List[QuestionWithSchemaInput]], + *args, + truncation: TruncationStrategy + ) -> BatchEncoding: + if isinstance(inputs, list): + if self.tokenizer.pad_token_id is None: + raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input") + inputs = [self._pre_process(input=input) for input in inputs] + padding = True + elif isinstance(inputs, QuestionWithSchemaInput): + inputs = self._pre_process(input=inputs) + padding = False + else: + raise ValueError( + f" `inputs`: {inputs} have the wrong format. 
They should be either of type `QuestionWithSchemaInput` or type `List[QuestionWithSchemaInput]`" + ) + encodings = self.tokenizer(inputs, padding=padding, truncation=truncation, return_tensors=self.framework) + # This is produced by tokenizers but is an invalid generate kwargs + if "token_type_ids" in encodings: + del encodings["token_type_ids"] + return encodings + @dataclass class ConversationalText2SQLInput(object): @@ -182,6 +287,7 @@ class ConversationalText2SQLGenerationPipeline(Text2TextGenerationPipeline): """ def __init__(self, *args, **kwargs): + logger.warning(f'kwargs is :{kwargs}') self.db_path: str = kwargs.pop("db_path") self.prefix: Optional[str] = kwargs.pop("prefix", None) self.normalize_query: bool = kwargs.pop("normalize_query", True) @@ -190,6 +296,8 @@ def __init__(self, *args, **kwargs): self.schema_serialization_with_db_id: bool = kwargs.pop("schema_serialization_with_db_id", True) self.schema_serialization_with_db_content: bool = kwargs.pop("schema_serialization_with_db_content", True) self.schema_cache: Dict[str, dict] = dict() + self.include_foreign_keys = kwargs.pop("include_foreign_keys", False) + logger.warning(f'include foreign keys is :{self.include_foreign_keys}') super().__init__(*args, **kwargs) def __call__(self, inputs: Union[ConversationalText2SQLInput, List[ConversationalText2SQLInput]], *args, **kwargs): @@ -265,8 +373,9 @@ def _parse_and_tokenize( def _pre_process(self, input: ConversationalText2SQLInput) -> str: prefix = self.prefix if self.prefix is not None else "" - if input.db_id not in self.schema_cache: - self.schema_cache[input.db_id] = get_schema(db_path=self.db_path, db_id=input.db_id) + # if input.db_id not in self.schema_cache: + # self.schema_cache[input.db_id] = get_schema(db_path=self.db_path, db_id=input.db_id) + schema = self.get_schema_from_cache(input.db_id) schema = self.schema_cache[input.db_id] if hasattr(self.model, "add_schema"): self.model.add_schema(db_id=input.db_id, db_info=schema) @@ -281,6 +390,8 @@ def _pre_process(self, input: ConversationalText2SQLInput) -> str: schema_serialization_with_db_id=self.schema_serialization_with_db_id, schema_serialization_with_db_content=self.schema_serialization_with_db_content, normalize_query=self.normalize_query, + include_foreign_keys=self.include_foreign_keys, + foreign_keys=schema["db_foreign_keys"] ) return cosql_get_input(utterances=input.utterances, serialized_schema=serialized_schema, prefix=prefix) @@ -302,9 +413,13 @@ def postprocess(self, model_outputs: dict, return_type=ReturnType.TEXT, clean_up records.append(record) return records +def get_db_file_path(db_path: str, db_id: str) -> str: + return db_path + "/" + db_id + "/" + db_id + ".sqlite" def get_schema(db_path: str, db_id: str) -> dict: - schema = dump_db_json_schema(db_path + "/" + db_id + "/" + db_id + ".sqlite", db_id) + db_file_path = db_path + "/" + db_id + "/" + db_id + ".sqlite" + print(f'reading schema from {db_file_path}') + schema = dump_db_json_schema(db_file_path, db_id) return { "db_table_names": schema["table_names_original"], "db_column_names": { @@ -318,3 +433,8 @@ def get_schema(db_path: str, db_id: str) -> dict: "other_column_id": [other_column_id for _, other_column_id in schema["foreign_keys"]], }, } + +def get_schema_for_display(db_path: str, db_id: str) -> dict: + db_file_path = db_path + "/" + db_id + "/" + db_id + ".sqlite" + schema = dump_db_json_schema(db_file_path, db_id) + return schema \ No newline at end of file diff --git a/seq2seq/utils/spider.py b/seq2seq/utils/spider.py index
9bf5607e..14ed2fbb 100644 --- a/seq2seq/utils/spider.py +++ b/seq2seq/utils/spider.py @@ -37,6 +37,8 @@ def spider_add_serialized_schema(ex: dict, data_training_args: DataTrainingArgum schema_serialization_with_db_id=data_training_args.schema_serialization_with_db_id, schema_serialization_with_db_content=data_training_args.schema_serialization_with_db_content, normalize_query=data_training_args.normalize_query, + include_foreign_keys=data_training_args.include_foreign_keys_in_schema, + foreign_keys=ex["db_foreign_keys"] ) return {"serialized_schema": serialized_schema} diff --git a/seq2seq/utils/trainer.py b/seq2seq/utils/trainer.py index 2d0f3253..6fffab6f 100644 --- a/seq2seq/utils/trainer.py +++ b/seq2seq/utils/trainer.py @@ -6,6 +6,20 @@ from datasets.metric import Metric import numpy as np import time +from pynvml import * +from loguru import logger + +def print_gpu_utilization(): + nvmlInit() + handle = nvmlDeviceGetHandleByIndex(0) + info = nvmlDeviceGetMemoryInfo(handle) + logger.info(f"GPU memory occupied: {info.used//1024**2} MB.") + + +def print_summary(result): + print(f"Time: {result.metrics['train_runtime']:.2f}") + print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}") + print_gpu_utilization() class EvalPrediction(NamedTuple): @@ -48,10 +62,12 @@ def evaluate( max_length: Optional[int] = None, max_time: Optional[int] = None, num_beams: Optional[int] = None, + **gen_kwargs ) -> Dict[str, float]: self._max_length = max_length self._max_time = max_time self._num_beams = num_beams + self._gen_kwargs = gen_kwargs # memory metrics - must set up as early as possible self._memory_tracker.start()
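The serve_seq2seq.py changes above add several REST endpoints on top of the original /ask route (database listing, schema inspection, schema creation/update, and question answering with an explicit serialized schema), and the Postman collection exercises them by hand. The sketch below shows how a client might call these endpoints; it is not part of the diff. The base URL, the `well`/`well2` database ids, and the placeholder schema string are assumptions taken from the Postman collection; the actual host and port come from the serve backend arguments.

```python
# Minimal client sketch for the endpoints added in seq2seq/serve_seq2seq.py.
# BASE is an assumption; the real host and port are configured on the server side.
import requests

BASE = "http://localhost:8000"

# List databases found under db_path (folders containing <db_id>/<db_id>.sqlite).
print(requests.get(f"{BASE}/database/").json())

# Inspect the raw schema and the serialized schema the model would be fed.
print(requests.get(f"{BASE}/schema/well").json())
print(requests.get(f"{BASE}/serialized-schema/well/").json())

# Ask a question against a known database; db_id and question are query parameters.
print(requests.get(
    f"{BASE}/ask/",
    params={"db_id": "well", "question": "what is the average length of the well bore"},
).json())

# Create a new database from DDL statements, then extend it with PATCH.
requests.post(f"{BASE}/schema/well2", json=[
    "CREATE TABLE well(id, name)",
    "CREATE TABLE wellbore(id, well_id, bore_length)",
])
requests.patch(f"{BASE}/schema/well2", json=[
    "CREATE TABLE well_2(id, name)",
    "CREATE TABLE wellbore_2(id, well_id, bore_length)",
])

# Ask with an explicit serialized schema instead of a db_id; the JSON body must
# match the Query model in serve_seq2seq.py: {"question": ..., "db_schema": ...}.
print(requests.post(
    f"{BASE}/ask-with-schema/",
    json={"question": "how many wells are there", "db_schema": "<serialized schema string>"},
).json())
```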