diff --git a/spunky-sputniks/.github/workflows/build.yml b/spunky-sputniks/.github/workflows/build.yml new file mode 100644 index 0000000..b98e95d --- /dev/null +++ b/spunky-sputniks/.github/workflows/build.yml @@ -0,0 +1,49 @@ +name: Build + +on: + push: + tags: + - v* + +concurrency: + group: build-${{ github.head_ref }} + +jobs: + build: + name: Build wheels and source distribution + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Install build dependencies + run: python -m pip install --upgrade build + + - name: Build + run: python -m build + + - uses: actions/upload-artifact@v3 + with: + name: artifacts + path: dist/* + if-no-files-found: error + + publish: + name: Publish release + needs: + - build + if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags') + runs-on: ubuntu-latest + + steps: + - uses: actions/download-artifact@v3 + with: + name: artifacts + path: dist + + - name: Push build artifacts to PyPI + uses: pypa/gh-action-pypi-publish@v1.5.1 + with: + skip_existing: true + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/spunky-sputniks/.github/workflows/lint.yaml b/spunky-sputniks/.github/workflows/lint.yaml new file mode 100644 index 0000000..b28782a --- /dev/null +++ b/spunky-sputniks/.github/workflows/lint.yaml @@ -0,0 +1,35 @@ +# GitHub Action workflow enforcing our code style. + +name: Lint + +# Trigger the workflow on both push (to the main repository, on the main branch) +# and pull requests (against the main repository, but from any repo, from any branch). +on: + push: + branches: + - main + pull_request: + +# Brand new concurrency setting! This ensures that not more than one run can be triggered for the same commit. +# It is useful for pull requests coming from the main repository since both triggers will match. +concurrency: lint-${{ github.sha }} + +jobs: + lint: + runs-on: ubuntu-latest + + env: + # The Python version your project uses. Feel free to change this if required. 
+ PYTHON_VERSION: "3.12" + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python ${{ env.PYTHON_VERSION }} + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Run pre-commit hooks + uses: pre-commit/action@v3.0.1 diff --git a/spunky-sputniks/.github/workflows/tests.yml b/spunky-sputniks/.github/workflows/tests.yml new file mode 100644 index 0000000..25ebd6d --- /dev/null +++ b/spunky-sputniks/.github/workflows/tests.yml @@ -0,0 +1,41 @@ +name: Tests + +on: + push: + branches: + - main + paths: + - src/discobase/** + pull_request: + branches: + - main + paths: + - src/discobase/** + +concurrency: + group: test-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + TEST_BOT_TOKEN: ${{ secrets.TEST_BOT_TOKEN }} + +jobs: + run-container-matrix: + name: Test matrix on Linux + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python 3.12 + uses: actions/setup-python@v4 + with: + python-version: "3.12" + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Run tests in matrix + run: hatch test --all diff --git a/spunky-sputniks/.gitignore b/spunky-sputniks/.gitignore new file mode 100644 index 0000000..87ba13b --- /dev/null +++ b/spunky-sputniks/.gitignore @@ -0,0 +1,35 @@ +# Files generated by the interpreter +__pycache__/ +*.py[cod] + +# Environment specific +.venv +venv +.env +env + +# Unittest reports +.coverage* + +# Logs +*.log + +# PyEnv version selector +.python-version + +# Built objects +*.so +dist/ +build/ + +# IDEs +# PyCharm +.idea/ +# VSCode +.vscode/ +# MacOS +.DS_Store + +# Builds +dist/ +test.py diff --git a/spunky-sputniks/.pre-commit-config.yaml b/spunky-sputniks/.pre-commit-config.yaml new file mode 100644 index 0000000..b6c64b6 --- /dev/null +++ b/spunky-sputniks/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +# Pre-commit configuration. +# See https://github.com/python-discord/code-jam-template/tree/main#pre-commit-run-linting-before-committing + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + args: [--markdown-linebreak-ext=md] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.3 + hooks: + - id: ruff + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) diff --git a/spunky-sputniks/LICENSE.txt b/spunky-sputniks/LICENSE.txt new file mode 100644 index 0000000..5a04926 --- /dev/null +++ b/spunky-sputniks/LICENSE.txt @@ -0,0 +1,7 @@ +Copyright 2021 Python Discord + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/spunky-sputniks/README.md b/spunky-sputniks/README.md new file mode 100644 index 0000000..a3c0182 --- /dev/null +++ b/spunky-sputniks/README.md @@ -0,0 +1,57 @@ +
+ discobase logo +

+
Python Discord Codejam 2024 Submission: Spunky Sputniks
+
+
+ +## Installation + +### Library + +```bash +$ pip install discobase +``` + +### Demo Bot + +You can add the demo bot to a server with [this integration](https://discord.com/oauth2/authorize?client_id=1268247436699238542&permissions=8&integration_type=0&scope=bot), or self-host it using the following commands: + +```bash +$ git clone https://github.com/zerointensity/discobase +$ cd discobase/src/demo +$ export DB_BOT_TOKEN="first bot token" +$ export BOOKMARK_BOT_TOKEN="second bot token" +$ python3 main.py +``` + +## Quickstart + +```py +import asyncio +import discobase + +db = discobase.Database("My database") + +@db.table +class User(discobase.Table): + name: str + password: str + +async def main(): + async with db.conn("My bot token"): + admin = await User.find(name="admin") + if not admin: + User.save(name="admin", password="admin") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Documentation + +Documentation is available [here](https://discobase.zintensity.dev). + +## License + +`discobase` is distributed under the `MIT` license. diff --git a/spunky-sputniks/docs/assets/column_cmd.gif b/spunky-sputniks/docs/assets/column_cmd.gif new file mode 100644 index 0000000..ebe503c Binary files /dev/null and b/spunky-sputniks/docs/assets/column_cmd.gif differ diff --git a/spunky-sputniks/docs/assets/delete_cmd.gif b/spunky-sputniks/docs/assets/delete_cmd.gif new file mode 100644 index 0000000..404e0ef Binary files /dev/null and b/spunky-sputniks/docs/assets/delete_cmd.gif differ diff --git a/spunky-sputniks/docs/assets/demo_bot.gif b/spunky-sputniks/docs/assets/demo_bot.gif new file mode 100644 index 0000000..329a973 Binary files /dev/null and b/spunky-sputniks/docs/assets/demo_bot.gif differ diff --git a/spunky-sputniks/docs/assets/discobase_blurple.png b/spunky-sputniks/docs/assets/discobase_blurple.png new file mode 100644 index 0000000..28372de Binary files /dev/null and b/spunky-sputniks/docs/assets/discobase_blurple.png differ diff --git a/spunky-sputniks/docs/assets/discobase_white.png b/spunky-sputniks/docs/assets/discobase_white.png new file mode 100644 index 0000000..62eedb9 Binary files /dev/null and b/spunky-sputniks/docs/assets/discobase_white.png differ diff --git a/spunky-sputniks/docs/assets/find_cmd.gif b/spunky-sputniks/docs/assets/find_cmd.gif new file mode 100644 index 0000000..0c0d7a4 Binary files /dev/null and b/spunky-sputniks/docs/assets/find_cmd.gif differ diff --git a/spunky-sputniks/docs/assets/insert_cmd.gif b/spunky-sputniks/docs/assets/insert_cmd.gif new file mode 100644 index 0000000..531a102 Binary files /dev/null and b/spunky-sputniks/docs/assets/insert_cmd.gif differ diff --git a/spunky-sputniks/docs/assets/reset_cmd.gif b/spunky-sputniks/docs/assets/reset_cmd.gif new file mode 100644 index 0000000..dfa38ca Binary files /dev/null and b/spunky-sputniks/docs/assets/reset_cmd.gif differ diff --git a/spunky-sputniks/docs/assets/schema_cmd.gif b/spunky-sputniks/docs/assets/schema_cmd.gif new file mode 100644 index 0000000..40e7082 Binary files /dev/null and b/spunky-sputniks/docs/assets/schema_cmd.gif differ diff --git a/spunky-sputniks/docs/assets/table_cmd.gif b/spunky-sputniks/docs/assets/table_cmd.gif new file mode 100644 index 0000000..cd34cce Binary files /dev/null and b/spunky-sputniks/docs/assets/table_cmd.gif differ diff --git a/spunky-sputniks/docs/assets/tablestats_cmd.gif b/spunky-sputniks/docs/assets/tablestats_cmd.gif new file mode 100644 index 0000000..c080169 Binary files /dev/null and b/spunky-sputniks/docs/assets/tablestats_cmd.gif 
differ diff --git a/spunky-sputniks/docs/assets/update_cmd.gif b/spunky-sputniks/docs/assets/update_cmd.gif new file mode 100644 index 0000000..720307c Binary files /dev/null and b/spunky-sputniks/docs/assets/update_cmd.gif differ diff --git a/spunky-sputniks/docs/demonstration.md b/spunky-sputniks/docs/demonstration.md new file mode 100644 index 0000000..65891bd --- /dev/null +++ b/spunky-sputniks/docs/demonstration.md @@ -0,0 +1,20 @@ +--- +hide: + - navigation +--- + +# Demonstration + +Now, we know that Discobase's main functionality is just a database for storing data: we're allowing developers to access this "database" to make their own applications. + +Quotes are applied around the term database because, as mentioned earlier, the database is actually a Discord server created by the user (or by the Discobase bot in many cases), making Discobase the intermediary to couple this connection. + +If we look at the scenairo below, a developer—from our team—has programmed a message bookmarking bot that uses the Discord context menu feature `App -> Bookmark` to store away the message onto the database. + +![demo_bot](assets/demo_bot.gif) + +This bot doesn't use any other database — no SQL, no MongoDB, nothing! Just our very own library, which stores it on a Discord server. + +After storing away our pertinent message to the database, we can use a slash command as per the developer's generous design to get all the bookmarks we've stored away. + +If you want to try this for yourself, you can [invite the bookmark bot](https://discord.com/oauth2/authorize?client_id=1268247436699238542&permissions=8&integration_type=0&scope=bot) to your server. diff --git a/spunky-sputniks/docs/discord_interface.md b/spunky-sputniks/docs/discord_interface.md new file mode 100644 index 0000000..439fd5b --- /dev/null +++ b/spunky-sputniks/docs/discord_interface.md @@ -0,0 +1,190 @@ +--- +hide: + - navigation +--- + +# Discord Interface + +A handful of essential commands are readily available for interacting with the Discobase discord bot. + +!!! note + + The commands shown in the example section will generally have a interface provided by Discord. + In these examples, we use a **Games** table which has the columns: **Name** and **Genre**. + +## Access the Table's Schema + +Checkout the data type for the columns in your table before performing `insert` or `update` operations. + +The `/schema` operation takes in the name of your table as input and outputs information such as the names of columns and their datatypes you have set them to. + +### Usage + +`/schema [table]` + +- **Table:** The name of the table you've created. + +### Example + +`/schema Games` + +![schema](assets/schema_cmd.gif) + +!!! warning "Limitation" + + - Considering the limit of fields is 25 on discord. The command can only show up to 25 columns, so we'll signify the limit as `field_length = 25` forming the following inequality: **C** <= `field_length` where **C** is the number of columns. + +## Update a Column's Value + +Users can modify the arbitrary value they have set to a specific column in their data; however, the data type has to be consistent with the column's data type. + +The `/update` slash command takes the following parameters: the name of the table, the name of the column, the old value, and the new value that should replace the old one. + +### Usage + +`/update [table] [column] [current_value] [new_value]` + +- **Table:** The table you want to perform an update on. +- **Column:** The column you want to update. 
+- **Current Value**: The current value saved in the column. +- **New Value**: Your new information. + +### Example + +`/update Games name Hit Man Absolution Tomb Raider` + +![update](assets/update_cmd.gif) + +!!! warning "Limitation" + + The user is disallowed from entering a new value that is not consistent with the predefined column's data type. + +## Retrieve Statistics Concerning Your Database + +Knowing pertinent information such as how many tables are in my database and what are the names of each table are easily answered using this command. + +`/tablestats` iterates over your database's tables to display the names you've assigned to them and it tallies up a count of how many you've made. + +### Usage + +`/tablestats` + +- There are no parameters for this command. + +### Example + +`/tablestats` + +![table_stats](assets/tablestats_cmd.gif) + +## Perform a Search on Your Data + +Finding information in your data is an essential task. + +The slash command `/find`will ask for the following information before performing a search such as the name of the table, the name of the column, and the value you want to look up. + +### Usage + +`/find [table] [column] [current_value]` + +- **Table:** The name of the table the column is in. +- **Column:** The name of the column. +- **Current Value**: The value to search for. + +### Example + +`/find Games name Batman` + +![find_cmd](assets/find_cmd.gif) + +!!! warning "Limitation" + + The `description` field in a rich embed is limited to `4096` characters. The searching being performed on the data only looks for an exact match. + +## Inserting a New Record Into Your Table + +The `/insert` command allows you to add data to your table directly from your discord database server, saving you from having to restart your bot to add additional records. + +### Usage + +`/insert [table] [record]` + +- **Table:** The table you want to insert a new record into. +- **Record:** The record you want to insert, formatted as JSON. Use the columns as keys, and record data as values. + +### Example + +`/insert games {"name": "rayman legends", "genre": "platformer"}` + +![insert_cmd](assets/insert_cmd.gif) + +## Deleting a Record From Your Table + +The `/delete` command allows you to delete a record from your table within your Discord database server. Just like `/insert`, it saves you from having to restart your bot just to delete a record. + +### Usage + +`/delete [table] [record]` + +- **Table:** The table you want to delete a record from. +- **Record:** The record you want to delete, formatted as a json. Use the columns as keys, and record data as values. + +### Example + +`/delete games {"name": "rayman legends", "genre": "platformer"}` + +![delete_cmd](assets/delete_cmd.gif) + +## Resetting Your Database + +Ever had a large database that you simply do not know what to do with anymore? The `/reset` command makes it easy to remove all records and tables from your database and start fresh within a couple seconds! No need to make a whole new database server. + +### Usage + +`/reset` + +- There are no parameters for this command. + +![reset_cmd](assets/reset_cmd.gif) + +## Viewing a Table + +The `/table` command displays a table in a nicely formatted rich embed, with the columns as field titles, and the records from those columns as the field descriptions. The data is numbered so that you can easily correlate each record with its group. + +### Usage + +`/table [name]` + +- **Name:** Name of the table. 
+
+### Example
+
+`/table games`
+
+![table_cmd](assets/table_cmd.gif)
+
+!!! warning "Limitation"
+
+    - Discord caps embeds at 25 fields, so the command can only show up to 25 columns. Writing that limit as `field_length = 25`, a table must satisfy **C** <= `field_length`, where **C** is the number of columns.
+    - Field titles have a limit of 256 characters, so some titles may be cut off with an ellipsis at the end.
+
+## Viewing a Column
+
+The `/column` command displays a column in a neat, paginated rich embed so you can visualize the column's data page by page.
+
+### Usage
+
+`/column [table] [name]`
+
+- **Table:** The table the column belongs to.
+- **Name:** Name of the column.
+
+### Example
+
+`/column games name`
+
+![column_cmd](assets/column_cmd.gif)
+
+!!! warning "Limitation"
+
+    Embed descriptions have a character limit of `4096`, so some particularly large data may not fit within the bounds.
diff --git a/spunky-sputniks/docs/index.md b/spunky-sputniks/docs/index.md
new file mode 100644
index 0000000..f39eaf0
--- /dev/null
+++ b/spunky-sputniks/docs/index.md
@@ -0,0 +1,75 @@
+---
+hide:
+  - navigation
+---
+
+ discobase logo +
+ +## What is Discobase? + +**Python Discord Codejam 2024** + +This year, the theme was "information overload." We took that to heart, and made a database library that turns Discord into a database through various algorithms, and wrote a library to interact with it, either programatically or through a Discord interface, as well as another bot to show the library off. Truly, we're overloading a Discord server with _lots_ of information. + +We used [discord.py](https://discordpy.readthedocs.io/) to interact with Discord and turn it into a data store, and used [Pydantic](https://docs.pydantic.dev/) for serializing database models. + +## Features + +- Pure Python, and pure Discord. +- Asynchronous. +- Fully type safe. + +## Installation + +### Stable + +Install the stable version of `discobase` using this commit: + +``` +$ pip install git+https://github.com/ZeroIntensity/discobase@e7604673d136d2eefcf727ef9326974a2ecc22ff +``` + +You can also install the latest version: + +``` +$ pip install discobase +``` + +!!! bug + + The stable version includes the admin commands for your database, but lacks <3.11 support, while the latest version is the opposite, as it has down to 3.8 support, but lacks admin commands. This is due to a last-minute oversight on our part, but there is nothing we can do at this point. + +## Quickstart + +```py +import discobase +import asyncio + +db = discobase.Database("My discord database") + +@db.table +class User(discobase.Table): + name: str + password: str + +async def main(): + async with db.conn("My bot token"): + ... + +asyncio.run(main()) +``` + +## Contributions + +Per the presentation requirements, here's what each team member contributed: + +- Everyone: Laid out the concepts for the core implementation and how it would work. You can see [this issue](https://github.com/ZeroIntensity/discobase/issues/4) for the discussion. +- [Zero](https://github.com/zerointensity/) and [Rubiks](https://github.com/Rubiks14): Implemented the core library functionality. +- [Skye](https://github.com/enskyeing) and [Gimpy](https://github.com/Gimpy3887): Built all the admin commands based on the core library. +- [Rubiks](https://github.com/Rubiks14): Wrote the demo bot as shown in the [demonstration section](https://discobase.zintensity.dev/demonstration/). + +## Copyright + +`discobase` is distributed under the MIT license. diff --git a/spunky-sputniks/docs/library.md b/spunky-sputniks/docs/library.md new file mode 100644 index 0000000..6cc4085 --- /dev/null +++ b/spunky-sputniks/docs/library.md @@ -0,0 +1,365 @@ +--- +hide: + - navigation +--- + +# Core Library + +## Introduction + +Discobase turns a Discord bot into a database manager, through a server of `name`. The top-level class for Discobase is `Database`: + +```py +import discobase + +db = discobase.Database("My discord database") +``` + +Internally, this would create a server called "My discord database," and then use that for all storage. If this server already exists, it simply uses the existing server. + +### Logging + +By default, logging in Discobase is disabled. The `Database()` constructor has a `logging` parameter that you can pass to enable logging: + +```py +import discobase + +db = discobase.Database("My discord database", logging=True) +``` + +However, this only enables the Discobase logging, it does _not_ enable the logging for [discord.py](https://discordpy.readthedocs.io/en/latest/) (which is also disabled by default). 
If you would like to enable that, use their [setup_logging](https://discordpy.readthedocs.io/en/latest/api.html?highlight=setup_logging#discord.utils.setup_logging) function. + +!!! note + + Note that Discobase *does not* use Python's built-in logging library. Instead, it uses [loguru](https://loguru.readthedocs.io/en/stable/). + +## Logging in + +It's worth noting that the `Database` constructor itself doesn't actually initialize the server. If we want to do anything useful, we need to log in — _that's_ when the server gets initialized. + +There are a few methods to log in, that depend on your use case, the simplest being `login()`: + +```py +import discobase +import asyncio + +db = discobase.Database("My discord database") + +async def main(): + await db.login("My bot token...") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +`login()` has a bit of a pitfall: it blocks the function from proceeding (as in, the `await` never finishes, at least without some magic). In that case, you have two options: `login_task()` and `conn()`. Let's start with `login_task`, which runs the connection in the background as a free-flying task. + +!!! note + + `login_task()` stores a reference to the task internally to prevent it from being accidentially deallocated while running, this is what we mean by "free-flying." + +For example: + +```py +import discobase +import asyncio + +db = discobase.Database("My discord database") + +async def main(): + db.login_task("My bot token...") + # Do something else, the bot is now running in the background + +if __name__ == "__main__": + asyncio.run(main()) +``` + +Notice the lack of an `await` before `db.login_task()` — that's intentional, and we'll talk about that more in a moment. + +!!! warning + + A Discobase bot should generally *only* be used for a database, and not anything else. If you want to use Discobase in your own Discord bot, use two bot tokens: one for Discobase, and one for your bot. + +### Waiting Until Ready + +After calling `login_task()`, there isn't really a guarantee that the database is connected, which can cause some odd "it works on my machine" problems. To ensure that you're good to go, you should call `wait_ready()`: + +```py +import discobase +import asyncio + +db = discobase.Database("My discord database") + +async def main(): + db.login_task("My bot token...") + await db.wait_ready() + # We can now safely use the database! + +if __name__ == "__main__": + asyncio.run(main()) +``` + +Note that while the `asyncio.Task` object returned by `login_task()` is "free-flying," it does _not_ force the event loop to stay open indefinitely. To keep the connection alive, you must `await` the task: + +```py +import discobase +import asyncio + +db = discobase.Database("My discord database") + +async def main(): + task = db.login_task("My bot token...") + await db.wait_ready() + # We can now safely use the database! + # ... + await task # Keeps the connection open + +if __name__ == "__main__": + asyncio.run(main()) +``` + +`login_task()` and `wait_ready()` might suffice, depending on your application, but in many cases you might want to connect and disconnect, without running for the lifetime of the program. + +For this use case, instead of just `login_task()` followed by `wait_ready()`, you want to use `conn()`, which is an [asynchronous context manager](https://docs.python.org/3/reference/datamodel.html#async-context-managers). 
This method calls `wait_ready()` for you, so you assume that the database is connected while in the `async with` block: + +```py +import discobase +import asyncio + +db = discobase.Database("My discord database") + +async def main(): + async with db.conn("My bot token..."): + # Do something with the database + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Tables + +Now that your database is ready to go, let's make a table. Discobase uses [Pydantic](https://docs.pydantic.dev/latest/) to define schemas, through the `discobase.Table` type, which is, more or less, a drop in for `pydantic.BaseModel`: + +```py +import discobase + +db = discobase.Database("My discord database") + +class User(discobase.Table): + name: str + password: str +``` + +!!! note + + Throughout this documentation, we'll refer to a type that inherits from `discobase.Table` (and incidentially is also decorated with `table()`) as a "schema" or something similar. + +However, we forgot something in the above example! `discobase.Table` only allows use of `User` as a schema, but the database still needs to know that it exists. We can do this via the `table()` decorator: + +```py +import discobase + +db = discobase.Database("My discord database") + +@db.table +class User(discobase.Table): + name: str + password: str +``` + +!!! warning + + It is not allowed to have multiple tables of the same name (*i.e.*, the name of the class). For example, the following will **not** work: + + ```py + + import discobase + + db = discobase.Database("My discord database") + + @db.table + class User(discobase.Table): + name: str + password: str + + + @db.table + class User(discobase.Table): + some_other_field: str + ``` + +Great, now `User` is visible to our `Database` object! + +### Late Tables + +At first glance, it may look like `@db.table()` will set everything up for you — this is not the case. In fact, `@db.table()` simply sets a few attributes, but the key is that it _marks_ the type as a schema. We can't do any actual initialization until the bot is logged in, so initialization happens _then_. + +For example, the following would cause some errors if we try to use the table, since we use our table after the bot has already been initialized: + +```py +import discobase +import asyncio + +db = discobase.Database("My discord database") + +async def main(): + async with db.conn("My bot token..."): + # By default, this is not allowed! + @db.table + class User(discobase.Table): + name: str + password: str + + # ... + +if __name__ == "__main__": + asyncio.run(main()) +``` + +OK, so what's the fix? The `table()` decorator still _marks_ the `User` type as part of the database in the above example, so all we need to do is tell the database to do it's table construction a second time — we can do this through the `build_tables()` method. Our fixed version of the example above would look like: + +```py +import discobase +import asyncio + +db = discobase.Database("My discord database") + +async def main(): + async with db.conn("My bot token..."): + @db.table + class User(discobase.Table): + name: str + password: str + + await db.build_tables() # Initialize `User` internally + # Using `User` is now OK! + +if __name__ == "__main__": + asyncio.run(main()) +``` + +!!! question "Why not call `build_tables()` automatically in `table()`?" + + Initializing is an *asynchronous* operation, and `table()` is not an asynchronous function. + We'd have to do lots of weird event loop hacks to get it to work this way. 
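+
+For a concrete picture, here's a rough sketch of how late registration can look when schemas live in their own module. It mirrors the demo bot's layout in this repository (a shared config module that owns `db`), extended with a lazy import to show where `build_tables()` fits; the `config` and `plugins.users` module names are made up for this example.
+
+```py
+# config.py: shared module that owns the database object
+import discobase
+
+db = discobase.Database("My discord database")
+
+# plugins/users.py: imported lazily, after the bot has connected
+from config import db
+import discobase
+
+@db.table
+class User(discobase.Table):
+    name: str
+    password: str
+
+# main.py
+import asyncio
+import importlib
+from config import db
+
+async def main():
+    async with db.conn("My bot token..."):
+        # Importing the module runs its @db.table decorator, which only
+        # *marks* User as a schema; nothing is initialized yet.
+        importlib.import_module("plugins.users")
+        await db.build_tables()  # Initialize every late table at once
+        # `User` is now safe to use.
+
+if __name__ == "__main__":
+    asyncio.run(main())
+```
+
+Whether the schema module is imported before or after `conn()`, the decorator itself never talks to Discord; the login (or a later `build_tables()` call) is what performs the actual initialization.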
+
+## Saving
+
+Now, let's write to the database — we can do this by calling `save()` on an instance of our schema type:
+
+```py
+import discobase
+import asyncio
+
+db = discobase.Database("My discord database")
+
+@db.table
+class User(discobase.Table):
+    name: str
+    password: str
+
+async def main():
+    async with db.conn("My bot token..."):
+        user = User(name="Peter", password="foobar")
+        await user.save()  # Saves this record to the database
+
+if __name__ == "__main__":
+    asyncio.run(main())
+```
+
+Note that in the above, we used `await` with `save()`. This isn't actually required, since `save()` returns a `Task`, not a coroutine! In many cases, you don't need to save the record right then and there, and you can run it as a background task. This is especially important when it comes to Discobase — the ratelimit can make saving very slow, so it might be useful to save in the background and not block the current function. For example, if you were to use Discobase in a web application:
+
+```py
+@app.get("/signup")
+async def signup(username: str, password: str):
+    User(name=username, password=password).save()  # This is launched as a background task!
+    return "..."
+```
+
+If we were to `await` the result of `save()` above, the request handler would block until the record was fully written to Discord, which can take a while under the ratelimit.
+
+## Querying
+
+We can look up saved records via `find()` (or `find_unique()`, if you want a unique database entry):
+
+```py
+async def main():
+    async with db.conn("My bot token..."):
+        users = await User.find(name="Peter")
+        for user in users:
+            print(f"Name: {user.name}, password: {user.password}")
+```
+
+Note that this works in a whitelist manner — we search for records matching the values in the query, rather than fetching everything and excluding what doesn't match. However, calling `find()` with no arguments is a special case that gets every entry in the table (note that this is a slow operation).
+
+### Unique Entries
+
+As mentioned above, you can also use `find_unique()` to get a unique entry:
+
+```py
+async def main():
+    async with db.conn("My bot token..."):
+        peter = await User.find_unique(name="Peter")
+```
+
+By default, `find_unique()` runs in strict mode, which ensures the following:
+
+- The record actually exists: an exception is raised if it wasn't found (_i.e._, `find_unique()` cannot return `None` when strict mode is enabled).
+- Only one matching entry exists: if multiple entries are found, an exception is thrown. With strict mode disabled, a missing record returns `None` and the first of multiple matches is returned instead.
+
+!!! info
+
+    This is type safe through `@typing.overload()` — if you pass `strict=True`, the signature of `find_unique()` will not hint a return value that can be `None`.
+
+## Updating
+
+It's worth noting that `save()` can only be used on non-saved instances — that is, instances that haven't had `save()` called on them already. Instances created manually via the constructor (for example, calling `User(...)` above) are _not_ saved, while objects returned by something like `find()` are considered saved, as they already exist in the database.
+
+So what about when you want to update an existing record? For that, you should call `update()`, which updates an existing record in-place. For example, if you wanted to change the record from the previous example:
+
+```py
+async def main():
+    async with db.conn("My bot token..."):
+        peter = await User.find_unique(name="Peter")
+        peter.password = "barfoo"
+        await peter.update()
+```
+
+Note that just like `save()`, `update()` returns a `Task`, meaning you can omit the `await` if you would like to perform the operation as a background task.
+ +```py +peter = await User.find_unique(name="Peter") +peter.password = "barfoo" +peter.update() # Run this in the background +``` + +### Deleting + +You can also delete a saved record via the `delete()` method: + +```py +async def main(): + async with db.conn("My bot token..."): + peter = await User.find_unique(name="Peter") + peter.delete() +``` + +Per `update()` and `save()`, this returns a `Task` that can be awaited or ran in the background. + +### Committing + +As you might have guessed, `update()` is the inverse of `save()`, in the sense that it only works for _saved_ objects. But what if you don't know if the object is saved or not? Technically speaking, you could check if the `__disco_id__` attribute is `-1` (e.g. `saved = peter.__disco_id__ != -1`), but that's not too convenient. + +Instead, you can use `commit()`, which does this check for you. `commit()` works for _both_ saved and non-saved objects, and also can be used as a background task: + +```py +peter = (await User.find_unique(name="Peter", strict=False)) or User(name="Peter", password="foobar") +peter.password = "barfoo" +peter.commit() # Works with both cases! +``` + +## Admin Commands + +Discobase comes with a set of admin commands to interact with your database right from Discord. First, you'll need to join the server, which is printed in the logs (see above on how to enable logging.) + +Once you've joined, you're ready to try the admin commands! See the next section on what commands exist. diff --git a/spunky-sputniks/docs/reference.md b/spunky-sputniks/docs/reference.md new file mode 100644 index 0000000..de1bf1f --- /dev/null +++ b/spunky-sputniks/docs/reference.md @@ -0,0 +1,10 @@ +--- +hide: + - navigation +--- + +# API Reference + +::: discobase.database +::: discobase.table +::: discobase.exceptions diff --git a/spunky-sputniks/hatch.toml b/spunky-sputniks/hatch.toml new file mode 100644 index 0000000..f55fbfd --- /dev/null +++ b/spunky-sputniks/hatch.toml @@ -0,0 +1,29 @@ +[version] +path = "src/discobase/__about__.py" + +[build.targets.sdist] +only-include = ["src/discobase/"] + +[build.targets.wheel] +packages = ["src/discobase/"] + +[envs.hatch-test] +installer = "pip" +dependencies = [ + "coverage[toml]~=7.4", + "pytest~=8.1", + "pytest-asyncio~=0.23" +] + +[envs.hatch-test.scripts] +run = "pytest{env:HATCH_TEST_ARGS:} {args} -s" +run-cov = "coverage run -m pytest{env:HATCH_TEST_ARGS:} {args}" +cov-combine = "coverage combine" +cov-report = "coverage report" + +[envs.docs] +dependencies = ["mkdocs", "mkdocstrings[python]", "mkdocs-material"] + +[envs.docs.scripts] +build = "mkdocs build --clean" +serve = "mkdocs serve" diff --git a/spunky-sputniks/mkdocs.yml b/spunky-sputniks/mkdocs.yml new file mode 100644 index 0000000..83200e2 --- /dev/null +++ b/spunky-sputniks/mkdocs.yml @@ -0,0 +1,91 @@ +site_name: Discobase +site_url: https://discobase.zintensity.dev +repo_url: https://github.com/ZeroIntensity/discobase +repo_name: ZeroIntensity/discobase + +nav: + - Home: index.md + - Core Library: library.md + - Discord Interface: discord_interface.md + - Demonstration: demonstration.md + - API Reference: reference.md + +theme: + name: material + logo: assets/discobase_white.png + palette: + - media: "(prefers-color-scheme)" + primary: black + accent: black + toggle: + icon: material/brightness-auto + name: Switch to light mode + + # Palette toggle for light mode + - media: "(prefers-color-scheme: light)" + scheme: default + primary: black + accent: black + toggle: + icon: material/brightness-7 + name: Switch to 
dark mode + + # Palette toggle for dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: black + accent: black + toggle: + icon: material/brightness-4 + name: Switch to system preference + features: + - content.tabs.link + - content.code.copy + - content.action.edit + - search.highlight + - search.share + - search.suggest + - navigation.footer + - navigation.indexes + - navigation.sections + - navigation.tabs + - navigation.tabs.sticky + - navigation.top + - toc.follow + + icon: + repo: fontawesome/brands/github + +markdown_extensions: + - attr_list + - md_in_html + - toc: + permalink: true + - pymdownx.highlight: + anchor_linenums: true + - pymdownx.inlinehilite + - pymdownx.superfences + - pymdownx.snippets + - admonition + - pymdownx.details + - pymdownx.tabbed: + alternate_style: true + +plugins: + - search + - mkdocstrings: + handlers: + python: + paths: [src] + options: + show_root_heading: true + show_object_full_path: false + show_symbol_type_heading: true + show_symbol_type_toc: true + show_signature: true + seperate_signature: true + show_signature_annotations: true + signature_crossrefs: true + show_source: true + show_if_no_docstring: true + show_docstring_examples: true diff --git a/spunky-sputniks/netlify.toml b/spunky-sputniks/netlify.toml new file mode 100644 index 0000000..37d891f --- /dev/null +++ b/spunky-sputniks/netlify.toml @@ -0,0 +1,3 @@ +[build] +command = "hatch run docs:build" +publish = "site" diff --git a/spunky-sputniks/pyproject.toml b/spunky-sputniks/pyproject.toml new file mode 100644 index 0000000..8ac360a --- /dev/null +++ b/spunky-sputniks/pyproject.toml @@ -0,0 +1,20 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "discobase" +description = "Database library using nothing but Discord. PyDis Codejam 2024." +readme = "README.md" +license = "MIT" +dependencies = ["discord.py", "pydantic", "typing_extensions", "loguru", "aiocache"] +dynamic = ["version"] + +[tool.ruff] +line-length = 79 # PEP 8 + +[tool.ruff.lint] +ignore = ["F403"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] diff --git a/spunky-sputniks/requirements.txt b/spunky-sputniks/requirements.txt new file mode 100644 index 0000000..cf2aae0 --- /dev/null +++ b/spunky-sputniks/requirements.txt @@ -0,0 +1,3 @@ +# Note: this file is *only* for Netlify's build! 
+# Project dependencies are stored in `pyproject.toml` +hatch diff --git a/spunky-sputniks/runtime.txt b/spunky-sputniks/runtime.txt new file mode 100644 index 0000000..cc1923a --- /dev/null +++ b/spunky-sputniks/runtime.txt @@ -0,0 +1 @@ +3.8 diff --git a/spunky-sputniks/src/demo/__about__.py b/spunky-sputniks/src/demo/__about__.py new file mode 100644 index 0000000..fdb23aa --- /dev/null +++ b/spunky-sputniks/src/demo/__about__.py @@ -0,0 +1,2 @@ +__version__ = "0.0.0-dev0" +__license__ = "MIT" diff --git a/spunky-sputniks/src/demo/__init__.py b/spunky-sputniks/src/demo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/spunky-sputniks/src/demo/bookmark_ui.py b/spunky-sputniks/src/demo/bookmark_ui.py new file mode 100644 index 0000000..42e6186 --- /dev/null +++ b/spunky-sputniks/src/demo/bookmark_ui.py @@ -0,0 +1,151 @@ +import db_interactions +import discord +import models + + +async def send_bookmark(interaction: discord.Interaction, record: models.BookmarkedMessage): + embed = build_bookmark_embed(record=record) + await interaction.followup.send(content="Successfully bookmarked the message", embed=embed, ephemeral=True) + +class BookmarkForm(discord.ui.Modal): + """The form where a user can fill in a custom title for their bookmark & submit it.""" + + bookmark_title = discord.ui.TextInput( + label="Choose a title for your bookmark (optional)", + placeholder="Type your bookmark title here", + default="Bookmark", + max_length=50, + min_length=0, + required=False, + ) + + def __init__(self, message: discord.Message): + super().__init__(timeout=1000, title="Name your bookmark") + self.message = message + + async def on_submit(self, interaction: discord.Interaction) -> None: + """Sends the bookmark embed to the user with the newly chosen title.""" + title = self.bookmark_title.value or self.bookmark_title.default + await interaction.response.defer(ephemeral=True) + record = await db_interactions.add(interaction, self.message, title) + await send_bookmark(interaction, record) + + +def build_bookmark_embed(record: models.BookmarkedMessage): + embed = discord.Embed(title=record.title, description=record.message_content, colour=0x68C290) + embed.set_author( + name=record.author_name, + icon_url=record.author_avatar_url + ) + return embed + +def build_embeds_list(records: list[models.BookmarkedMessage]) -> list[discord.Embed]: + embeds: list[discord.Embed] = [] + for record in records: + embed = build_bookmark_embed(record) + embeds.append(embed) + return embeds + + +def build_error_embed(embed_content: str): + embed = discord.Embed(title="Error Saving embed", description=embed_content) + embed.set_author(name="Bookmark Bot") + return embed + + +class ArrowButtons(discord.ui.View): + def __init__(self, records: list[models.BookmarkedMessage]) -> None: + super().__init__(timeout=None) + self.records = records + self.content = build_embeds_list(records) + self.position = 0 + self.pages = len(self.content) + self.on_ready() + + @discord.ui.button(label='⬅️', style=discord.ButtonStyle.primary, custom_id='l_button') + async def back(self, interaction: discord.Interaction, button: discord.ui.Button) -> None: + """Controls the left button on the qotd list embed""" + # move back a position in the embed list + self.position -= 1 + + # check if we're on the first page, then disable the button to go left if we are (cant go anymore left) + if self.position == 0: + button.disabled = True + + # set the right button to a variable + right_button = [x for x in self.children if x.custom_id == 
'r_button'][0] + + # check if we're not on the last page, if yes then enable right button + if not self.position == self.pages - 1: + right_button.disabled = False + + # update discord message + await interaction.response.edit_message(embed=self.content[self.position], view=self) + + @discord.ui.button(label='➡️️️', style=discord.ButtonStyle.primary, custom_id='r_button') + async def forward(self, interaction: discord.Interaction, button: discord.ui.Button) -> None: + """Controls the right button on the qotd list embed""" + # move forward a position in the embed list + self.position += 1 + + # set a variable for left button + left_button = [x for x in self.children if x.custom_id == 'l_button'][0] + # check if we're not on the first page, if yes then enable left button + if not self.position == 0: + left_button.disabled = False + + # check if we're on the last page, if yes then disable right button + if self.position == self.pages - 1: + button.disabled = True + + # update discord message + await interaction.response.edit_message(embed=self.content[self.position], view=self) + + @discord.ui.button(label='🗑', style=discord.ButtonStyle.danger, custom_id='del_button') + async def delete(self, interaction: discord.Interaction, _: discord.ui.Button) -> None: + """Controls the delete button on the qotd list embed""" + + # remove the entry from the database + await db_interactions.remove(self.records[self.position]) # This is causing some errors to check in the morning + self.records.remove(self.records[self.position]) + + # remove the embed from the message + self.content.remove(self.content[self.position]) + self.pages = len(self.content) + + # Only change position if the deleted item is not the first one + if self.position == 0: + self.position = self.position + else: + self.position -= 1 + + # set a variable for left button + left_button: discord.Button = [x for x in self.children if x.custom_id == 'l_button'][0] + # check if we're not on the first page, if yes then enable left button + if self.position == 0: + left_button.disabled = True + + # set the right button to a variable + right_button: discord.Button = [x for x in self.children if x.custom_id == 'r_button'][0] + # check if we're not on the last page, if yes then enable right button + if self.position == self.pages - 1: + right_button.disabled = True + + # Edit the message if there is still data. otherwise delete it. 
+ if self.pages > 0: + await interaction.response.edit_message(embed=self.content[self.position], view=self) + else: + await interaction.response.edit_message(content="You have no more saved bookmarks", embed=None, view=None) + + def on_ready(self) -> None: + """Checks the number of pages to decide which buttons to have enabled/disabled""" + left_button = [x for x in self.children if x.custom_id == 'l_button'][0] + right_button = [x for x in self.children if x.custom_id == 'r_button'][0] + + # if we only have one page, disable both buttons + if self.pages == 1: + left_button.disabled = True + right_button.disabled = True + # if we have more than one page, only disable the left button for the first page + else: + left_button.disabled = True diff --git a/spunky-sputniks/src/demo/db_interactions.py b/spunky-sputniks/src/demo/db_interactions.py new file mode 100644 index 0000000..82881c9 --- /dev/null +++ b/spunky-sputniks/src/demo/db_interactions.py @@ -0,0 +1,49 @@ +import discord +import models +from demobot_config import default_icon + +import discobase + + +async def add(interaction: discord.Interaction, message: discord.Message, title: str) -> models.BookmarkedMessage: + """Add a message to the bookmarks. + + Args: + interaction: The `discord.Interaction` that initiated the command + message: The `discord.Message` that is being bookmarked + title: The title provided by the modal + Returns: + models.BookmarkedMessage: the record that was saved to the database + """ + + avatar_url = message.author.display_avatar.url if message.author.display_avatar is not None else default_icon + record = models.BookmarkedMessage( + user_id=interaction.user.id, + title=title, + author_name=message.author.name, + author_avatar_url=avatar_url, + message_content=message.content + ) + await record.save() + return record + +async def get(db: discobase.Database, interaction: discord.Interaction) -> list[models.BookmarkedMessage]: + """Get bookmarks for a user, or across the whole server. If getting bookmarks for the whole sever, a search string is required. + + Args: + db: Discobase database instance + interaction: discord interaction that triggered the function + + Returns: + list[models.BookmarkedMessage]: the list of bookmarks saved by the user + """ + return await db.tables[models.BookmarkedMessage.__name__.lower()].find(user_id = interaction.user.id) + +async def remove(record: models.BookmarkedMessage) -> None: + """Remove a bookmark from the list. + + Args: + db: discobase database instance. 
+ record: the record to delete + """ + await record.delete() diff --git a/spunky-sputniks/src/demo/demobot_commands.py b/spunky-sputniks/src/demo/demobot_commands.py new file mode 100644 index 0000000..1842e66 --- /dev/null +++ b/spunky-sputniks/src/demo/demobot_commands.py @@ -0,0 +1,40 @@ +import bookmark_ui +import db_interactions +import discord +from demobot_config import db +from discord.ext import commands + + +class Bookmark(discord.app_commands.Group): + def __init__(self, bot: commands.Bot): + super().__init__(name="bookmark") + self.bot = bot + self.bookmark_context_menu = discord.app_commands.ContextMenu(name="Bookmark", callback=self.bookmark_message_callback) + self.bot.tree.add_command(self.bookmark_context_menu) + + async def bookmark_message_callback(self, interaction: discord.Interaction, message: discord.Message) -> None: + """ + The callback in charge of creating a bookmark when the context menu is selected + + Args: + interaction: discord.Interaction that triggered the save + message: discord.Message that is being saved + """ + bookmark_form = bookmark_ui.BookmarkForm(message=message) + await interaction.response.send_modal(bookmark_form) + + @discord.app_commands.command(name="get_bookmarks", description="Retrieve all of your bookmarks") + async def get_bookmarks(self, interaction: discord.Interaction) -> None: + """ + Creates a message with all of the user's bookmarks that can be flipped through and deleted + + Args: + interaction: discord.Interaction that triggered the save + """ + await interaction.response.defer(ephemeral=True, thinking=True) + records = await db_interactions.get(db, interaction) + if len(records) == 0: + await interaction.followup.send("You have not bookmarked any messages") + else: + buttons = bookmark_ui.ArrowButtons(records=records) + await interaction.followup.send(view=buttons, embed=buttons.content[0], ephemeral=True) diff --git a/spunky-sputniks/src/demo/demobot_config.py b/spunky-sputniks/src/demo/demobot_config.py new file mode 100644 index 0000000..0060eba --- /dev/null +++ b/spunky-sputniks/src/demo/demobot_config.py @@ -0,0 +1,5 @@ +import discobase + +db = discobase.Database("personal_discobase_server") + +default_icon = "https://i.imgur.com/2QH3tEQ.png" diff --git a/spunky-sputniks/src/demo/main.py b/spunky-sputniks/src/demo/main.py new file mode 100644 index 0000000..3abac2b --- /dev/null +++ b/spunky-sputniks/src/demo/main.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import asyncio +import os + +import demobot_commands +import discord +from demobot_config import db +from loguru import logger + + +class BookmarkBot(discord.Client): + def __init__(self): + super().__init__(intents=discord.Intents.all(), command_prefix="!") + self.tree = discord.app_commands.CommandTree(self) + self.tree.add_command(demobot_commands.Bookmark(self)) + + @logger.catch(reraise=True) + async def on_ready(self) -> None: + try: + await self.tree.sync() + logger.info(f"Logged in as {self.user}") + logger.debug(f"{self.tree.client}") + logger.debug(f"Loaded the following commands: {await self.tree.fetch_commands()}") + except Exception as e: + print(f"{e.__class__.__name__}: {e}") + + async def on_error(self, event_method: str, /, *args: asyncio.Any, **kwargs: asyncio.Any) -> None: + return await super().on_error(event_method, *args, **kwargs) + +discord.utils.setup_logging() +bot = BookmarkBot() + +async def main() -> None: + async with db.conn(os.getenv("DB_BOT_TOKEN")): + try: + await bot.start(os.getenv("BOOKMARK_BOT_TOKEN")) + finally: + 
await bot.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/spunky-sputniks/src/demo/models.py b/spunky-sputniks/src/demo/models.py new file mode 100644 index 0000000..af3837b --- /dev/null +++ b/spunky-sputniks/src/demo/models.py @@ -0,0 +1,12 @@ +from demobot_config import db + +import discobase + + +@db.table +class BookmarkedMessage(discobase.Table): + user_id: int + title: str + author_name: str + author_avatar_url: str + message_content: str diff --git a/spunky-sputniks/src/demo/py.typed b/spunky-sputniks/src/demo/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/spunky-sputniks/src/discobase/__about__.py b/spunky-sputniks/src/discobase/__about__.py new file mode 100644 index 0000000..dbee0ed --- /dev/null +++ b/spunky-sputniks/src/discobase/__about__.py @@ -0,0 +1,2 @@ +__version__ = "1.0.0" +__license__ = "MIT" diff --git a/spunky-sputniks/src/discobase/__init__.py b/spunky-sputniks/src/discobase/__init__.py new file mode 100644 index 0000000..3d8a05b --- /dev/null +++ b/spunky-sputniks/src/discobase/__init__.py @@ -0,0 +1,10 @@ +""" +Discobase - Relation database library using nothing but Discord. + +Python Discord Codejam 2024 +""" + +from .__about__ import __license__, __version__ +from .database import * +from .exceptions import * +from .table import * diff --git a/spunky-sputniks/src/discobase/_cursor.py b/spunky-sputniks/src/discobase/_cursor.py new file mode 100644 index 0000000..9f85ae8 --- /dev/null +++ b/spunky-sputniks/src/discobase/_cursor.py @@ -0,0 +1,1012 @@ +from __future__ import annotations + +import asyncio +import hashlib +from base64 import urlsafe_b64decode, urlsafe_b64encode +from collections.abc import Iterable +from datetime import timedelta +from functools import lru_cache +from typing import TYPE_CHECKING, Any, Callable, List, Optional + +import discord +from aiocache import cached +from discord.utils import snowflake_time, time_snowflake +from loguru import logger +from pydantic import BaseModel, ValidationError + +from ._metadata import Metadata +from .exceptions import (DatabaseCorruptionError, DatabaseLookupError, + DatabaseStorageError) + +if TYPE_CHECKING: + from .table import Table + +__all__ = ("TableCursor",) + + +class _Record(BaseModel): + content: str + """Base64 encoded Pydantic model dump of the record.""" + + @classmethod + def from_data(cls, data: Table) -> _Record: + logger.debug(f"Generating a _Record from data: {data}") + return _Record( + content=urlsafe_b64encode( # Record JSON data is stored in base64 + data.model_dump_json().encode("utf-8"), + ).decode("utf-8"), + ) + + def decode_content(self, record: Table | type[Table]) -> Table: + return record.model_validate_json(urlsafe_b64decode(self.content)) + + +class _IndexableRecord(BaseModel): + key: int + """Hashed value of the key.""" + record_ids: List[int] + """Message IDs of the records that correspond to this key.""" + next_value: Optional[_IndexableRecord] = None + """ + Temporary placeholder value for the next entry. + Only for use in resizing. + """ + + @classmethod + def from_message(cls, message: str) -> _IndexableRecord | None: + """ + Generate an `_IndexableRecord` instance from message content. + + Args: + message: Message content to parse as JSON. + + Returns: + _IndexableRecord | None: An `_IndexableRecord` instance, or `None`, + if the message was `null`. 
+ """ + logger.debug(f"Parsing {message} into an _IndexableRecord") + try: + return ( + cls.model_validate_json(message) if message != "null" else None + ) + except ValidationError as e: + raise DatabaseCorruptionError( + f"got bad _IndexableRecord entry: {message}" + ) from e + + +class _HashTransport: + """ + Hacky object to use `hash()` for tuples and dictionaries + that retains the value between interpreters. + """ + + def __init__(self, hash_num: int) -> None: + self.hash_num = hash_num + + def __hash__(self) -> int: + return self.hash_num + + +class TableCursor: + def __init__( + self, + metadata: Metadata, + metadata_channel: discord.TextChannel, + guild: discord.Guild, + ) -> None: + self.metadata = metadata + self.metadata_channel = metadata_channel + self.guild = guild + + @lru_cache + def _find_channel(self, channel_id: int) -> discord.TextChannel: + for channel in self.guild.channels: + if channel.id != channel_id: + continue + + if not isinstance(channel, discord.TextChannel): + raise DatabaseCorruptionError( + f"{channel!r} is not a TextChannel" + ) + + return channel + + raise DatabaseCorruptionError( + f"could not find channel with id {channel_id}" + ) + + async def _find_collision_message( + self, + channel: discord.TextChannel, + index: int, + *, + search_func: Callable[[str], bool] = lambda s: s == "null", + ) -> discord.Message: + """ + Search for a message via a worst-case O(n) search in the event + of a hash collision. + + Args: + channel: Index channel to search. + index: The index to start at. + search_func: Function to check if the message content is good. + + Returns: + discord.Message: The message that satisfies search_func + """ + logger.debug( + f"Looking up hash collision entry using search function: {search_func}" # noqa + ) + offset: int = index + while True: + if (offset + 1) >= self.metadata.max_records: + logger.debug("We need to wrap around the table.") + offset = 0 + else: + offset += 1 + + if offset == index: + raise DatabaseCorruptionError( + f"index channel {channel!r} has no free messages, table was likely not resized." # noqa + ) + + message = await self._lookup_message( + channel, + offset, + ) + logger.debug( + f"Hash collision search at index: {offset} {message=}", + ) + if search_func(message.content): + logger.debug( + f"Done searching for collision message: {message.content}" + ) + return message + + async def _edit_message( + self, + channel: discord.TextChannel, + mid: int, + content: str, + ) -> None: + """ + Edit a message given the channel, message ID, and content. + + This should *not* be used over `discord.Message.edit`, it's simply + a handy utility to use when you only have the message ID. + """ + editable_message = await channel.fetch_message(mid) + logger.debug(f"Editing message (ID {mid}) to {content}") + await editable_message.edit(content=content) + + def _to_index(self, value: int) -> int: + """ + Generate an index from a hash number based + on the metadata's `max_records`. + + Args: + value: Integer hash, positive or negative. + + Returns: + int: Index in range of `metadata.max_records`. + """ + index = (value & 0x7FFFFFFF) % self.metadata.max_records + logger.debug( + f"Hashed value {value} turned into index: {index} ({self.metadata.max_records=})" # noqa + ) + return index + + @lru_cache() + def _hash( + self, + value: Any, + ) -> int: + """ + Hash the field into an integer. + + Args: + value: Any discobase-hashable object. + + Returns: + int: An integer, positive or negative, representing a unique hash. 
+ This is always the same thing across programs. + """ + logger.debug(f"Hashing object: {value!r}") + if isinstance(value, str): + hashed_str = int( + hashlib.sha1(value.encode("utf-8")).hexdigest(), + 16, + ) + logger.debug(f"Hashed string {value!r} into {hashed_str}") + return hashed_str + elif isinstance(value, dict): + transport: dict[_HashTransport, _HashTransport] = {} + + for k, v in value.items(): + transport[_HashTransport(self._hash(k))] = _HashTransport( + self._hash(v) + ) + + hashed_dict = hash(transport) + logger.debug(f"Hashed dictionary {value!r} into {hashed_dict}") + return hashed_dict + elif isinstance(value, Iterable): + hashes: list[_HashTransport] = [] + for item in value: + hashes.append(_HashTransport(self._hash(item))) + + hashed_tuple = hash(tuple(hashes)) + logger.debug(f"Hashed iterable {value!r} into {hashed_tuple}") + return hashed_tuple + elif isinstance(value, int): + return value + else: + raise DatabaseStorageError(f"unhashable: {value!r}") + + def _as_hashed( + self, + value: Any, + ) -> tuple[int, int]: + """ + Get the hash number and index for `value`. + """ + hashed = self._hash(value) + return hashed, self._to_index(hashed) + + @cached() + async def _lookup_message_impl( + self, + channel: discord.TextChannel, + index: int, + ) -> discord.Message: + """ + The *implementation* of looking up a message by + it's index in the table. You need to call `fetch()` + on the result of this function due to caching. + + Args: + channel: Index channel to search. + index: Index to get. + + Returns: + discord.Message: The found message. + + Raises: + DatabaseCorruptionError: Could not find the index. + """ + metadata = self.metadata + logger.debug(f"Looking up message: {index}") + for timestamp, rng in metadata.time_table.items(): + start: int = rng[0] + end: int = rng[1] + if index not in range( + start, end + ): # Pydantic doesn't support ranges + continue + + logger.debug(f"In range: {start} - {end}") + current_index: int = 0 + async for msg in channel.history( + limit=end - start, + before=snowflake_time(timestamp), + ): + if current_index == (index - start): + logger.debug(f"{msg} found at index {current_index}") + return msg + current_index += 1 + + raise DatabaseCorruptionError( + f"range for {timestamp} in table {metadata.name} does not contain index {index}" # noqa + ) + + raise DatabaseCorruptionError( + f"message index out of range for table {metadata.name}: {index}" + ) + + async def _lookup_message( + self, + channel: discord.TextChannel, + index: int, + ) -> discord.Message: + """ + Lookup a message by it's index in the table. + + Args: + channel: Index channel to search. + index: Index to get. + + Returns: + discord.Message: The found message. + + Raises: + DatabaseCorruptionError: Could not find the index. + """ + # We need to refetch it for the latest content. + return await (await self._lookup_message_impl(channel, index)).fetch() + + async def _resize_hash( + self, + index_channel: discord.TextChannel, + amount: int, + ) -> int: + """ + Increases the hash for `index_channel` by amount + + Args: + index_channel: the channel that contains index data for a database + amount: the amount to increase the size by + + Returns: + int: snowflake representation of when the last message of the + resize was created + """ + last_message: discord.Message | None = None + + # Here be dragons: ratelimit makes gathering this actually worse. 
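+        # The placeholder messages are sent one at a time on purpose: each send
+        # is throttled by Discord's per-channel ratelimit either way, so firing
+        # them concurrently with asyncio.gather() only adds overhead.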
+ for _ in range(amount): + last_message = await index_channel.send("null", silent=True) + + if not last_message: + raise DatabaseCorruptionError("last_message is None somehow") + # 5 seconds, per the Discord ratelimit + last_timestamp = timedelta(seconds=5) + last_message.created_at + return time_snowflake(last_timestamp) + + async def _resize_channel( + self, + channel: discord.TextChannel, + ) -> None: + """ + The implementation of resizing a channel. This method assumes + that `self.metadata.max_records` has already been doubled. + + This is meant for use in `gather()`, for optimal performance. + + Args: + channel: Index channel to resize. + """ + metadata = self.metadata + logger.debug( + f"Resizing channel: {channel!r} for table {metadata.name}", + ) + old_size: int = metadata.max_records // 2 + timestamp_snowflake = await self._resize_hash(channel, old_size) + rng = ( + old_size, + metadata.max_records, + ) + + for snowflake, time_range in metadata.time_table.copy().items(): + # We only want one time stamp for the range, this forces + # the latest one to always be used -- that's a good thing, + # we don't want to risk having messages from the previous range + # in this one. + if time_range == rng: + del metadata.time_table[snowflake] + + metadata.time_table[timestamp_snowflake] = rng + # Now, we have to move everything into the correct position. + # + # Note that this shouldn't put everything into memory, as + # each previous iteration will be freed -- this is good + # for scalability. + # + # Due to Discord's ratelimit, gathering the coros in this loop + # is actually a bad idea. + async for msg in channel.history( + limit=old_size, + oldest_first=True, + ): + # msg = await channel.fetch_message(msg.id) + record = _IndexableRecord.from_message(msg.content) + if not record: + continue + + new_index: int = self._to_index(record.key) + target = await self._lookup_message( + channel, + new_index, + ) + + next_record = _IndexableRecord.from_message(target.content) + inplace: bool = True + overwrite: bool = True + + if next_record: + if next_record.next_value: + logger.info("Hash collision in resize!") + target = await self._find_collision_message( + channel, + new_index, + ) + # `inplace` is True, so we fall + # through to the inplace edit. + # + # To be fair, I'm not too sure if this is + # the best approach, this might be worth + # refactoring in the future. + else: + logger.info("Updating record at the new index.") + inplace = False + logger.debug( + f"{next_record} marked as the next value location ({target.id=})" # noqa + ) + + if record.next_value: + record.next_value = None + # Here be dragons: if we overwrite the `next_value` + # with `None` to prevent a doubly-nested copy in + # the JSON, we have to mark this message to *not* + # be overwritten, otherwise we lose that data. + overwrite = False + + next_record.next_value = record + content = next_record.model_dump_json() + logger.debug(f"Editing {target.content} to {content}") + await target.edit(content=content) + + if inplace: + # In case of a hash collision, we want to mark + # this as having a `next_value`, so it doesn't get + # overwritten. + # + # We copy this to prevent a recursive model dump. + if record.next_value: + record.next_value = None + overwrite = False + + copy = record.model_copy() + copy.next_value = record + logger.info( + "Target index does not have an entry, updating in-place." 
# noqa + ) + content = copy.model_dump_json() + logger.debug(f"Editing in-place null to {content}") + assert target.content == "null" + await target.edit(content=content) + + # Technically speaking, the index could + # remain the same. We need to check for that. + if (not record.next_value) and (target != msg) and overwrite: + await msg.edit(content="null") + + # Finally, all the next_value attributes have been set, we can + # go through and update each record. + # + # The overall algorithm is O(2n), but it's much more scalable + # than trying to put the entire table into memory in order to + # resize it. + # + # This algorithm is pretty much infinitely scalable + # in terms of memory, but we're limited by Discord's ratelimit. + async for msg in channel.history( + limit=metadata.max_records, + oldest_first=True, + ): + record = _IndexableRecord.from_message(msg.content) + if not record: + continue + + logger.debug(f"Handling movement of {record!r}") + if not record.next_value: + raise DatabaseCorruptionError( + "all existing records after resize should have next_value", # noqa + ) + + if record.next_value.next_value: + raise DatabaseCorruptionError( + f"doubly nested next_value found: {record.next_value.next_value!r} in {record!r}" # noqa + ) + + content = record.next_value.model_dump_json() + logger.debug(f"Replacing {msg.content} with {content}") + await msg.edit(content=content) + + async def _resize_table(self) -> None: + """ + Resize all the index channels in a table. + """ + metadata = self.metadata + metadata.max_records *= 2 + logger.info( + f"Resizing table {metadata.name} to {metadata.max_records}" # noqa + ) + await asyncio.gather( + *[ + self._resize_channel(self._find_channel(cid)) + for cid in metadata.index_channels.values() + ] + ) + + # Dump the new metadata + await self._edit_message( + self.metadata_channel, + metadata.message_id, + metadata.model_dump_json(), + ) + logger.info( + f"Table {metadata.name} is now of size {metadata.max_records}" + ) + + async def _inc_records(self) -> None: + """ + Increment the `current_records` number on the + target metadata. This resizes the table if the maximum + size is reached. + """ + metadata = self.metadata + metadata.current_records += 1 + if metadata.current_records > metadata.max_records: + logger.info("The table is full! We need to resize it.") + await self._resize_table() + + await self._edit_message( + self.metadata_channel, + metadata.message_id, + metadata.model_dump_json(), + ) + + async def _write_index_record( + self, + channel: discord.TextChannel, + index: int, + hashed: int, + record_id: int, + ) -> None: + """ + Write an index record to the specified channel, using + a known hash and index. + + Args: + channel: Target index channel to store the index record at. + index: Index to store the record at in the table. + hashed: Integer hash of the original value e.g. from `_hash`. + record_id: Message ID of the record in the main table. 
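+
+        Three cases are handled here: a "null" slot is overwritten in place
+        (and the record count is incremented), a slot that already holds the
+        same key gets `record_id` appended to its ID list, and any other
+        occupied slot means a hash collision, so the next free slot is used.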
+ """ + entry_message: discord.Message = await self._lookup_message( + channel, + index, + ) + serialized_content = _IndexableRecord.from_message( + entry_message.content + ) + + if not serialized_content: + logger.info("This is a null entry, we can just update in place.") + await self._inc_records() + message_content = _IndexableRecord( + key=hashed, + record_ids=[ + record_id, + ], + ) + await entry_message.edit(content=message_content.model_dump_json()) + elif serialized_content.key == hashed: + # See https://github.com/ZeroIntensity/discobase/issues/50 + # + # We don't want to call _inc_records() here, because we aren't + # using up a `null` space. + logger.info("This already exists, let's append to the data.") + serialized_content.record_ids.append(record_id) + await entry_message.edit( + content=serialized_content.model_dump_json() + ) + else: + logger.info("Hash collision!") + await self._inc_records() + index_message = await self._find_collision_message( + channel, + index, + ) + collision_entry = _IndexableRecord( + key=hashed, + record_ids=[ + record_id, + ], + ) + await index_message.edit(content=collision_entry.model_dump_json()) + + async def add_record(self, record: Table) -> discord.Message: + """ + Writes a record to an existing table. + + Args: + record: The record object being written to the table + + Returns: + discord.Message: The `discord.Message` that contains the new entry. + """ + + metadata = self.metadata + record_data = _Record.from_data(record) + main_table: discord.TextChannel = self._find_channel( + metadata.table_channel + ) + message = await main_table.send( + record_data.model_dump_json(), silent=True + ) + + for field, value in record.model_dump().items(): + channel = self._find_channel( + metadata.index_channels[f"{record.__disco_name__}_{field}"] + ) + hashed_field, target_index = self._as_hashed(value) + await self._write_index_record( + channel, + target_index, + hashed_field, + message.id, + ) + + return await message.edit(content=record_data.model_dump_json()) + + async def update_record(self, record: Table) -> discord.Message: + """ + Updates an existing record in a table. + + Args: + record: The record object being written to the table + + Returns: + discord.Message: The `discord.Message` that contains the new entry. 
+ """ + if record.__disco_id__ == -1: + # Sanity check + raise DatabaseCorruptionError("record must have an id to update") + + metadata = self.metadata + main_table: discord.TextChannel = self._find_channel( + metadata.table_channel + ) + msg = await main_table.fetch_message(record.__disco_id__) + current = _Record.model_validate_json(msg.content).decode_content( + record + ) + await msg.edit(content=_Record.from_data(record).model_dump_json()) + + for new, old in zip( + record.model_dump().items(), + current.model_dump().items(), + ): + field = new[0] + if field != old[0]: + raise DatabaseCorruptionError( + f"field name {field} does not match {old[0]}" + ) + + new_value = new[1] + old_value = old[1] + if new_value == old_value: + logger.info("Nothing changed.") + continue + + channel = self._find_channel( + metadata.index_channels[f"{record.__disco_name__}_{field}"] + ) + hashed_field, target_index = self._as_hashed(new_value) + await self._write_index_record( + channel, + target_index, + hashed_field, + msg.id, + ) + + old_index = self._to_index(self._hash(old_value)) + old_msg = await self._lookup_message(channel, old_index) + old_record = _IndexableRecord.from_message(old_msg.content) + if not old_record: + raise DatabaseCorruptionError( + "got null record somehow", + ) + + if len(old_record.record_ids) == 1: + logger.info("We can nullify this entry.") + await old_msg.edit(content="null") + self.metadata.current_records -= 1 + else: + logger.info( + "There are other entries with this value, only remove this ID." # noqa + ) + old_record.record_ids.remove(msg.id) + await old_msg.edit(content=old_record.model_dump_json()) + + return msg + + async def find_records( + self, + table: type[Table], + query: dict[str, Any], + ) -> list[Table]: + """ + Find a record based on the specified field values. + + Args: + table: Table type to find records for. + query: Dictionary containing field-value pairs. + + Returns: + list[Table]: A list of `Table` objects (or really, a list of + objects that inherit from `Table`), with the appropriate values + specified by `query`. + """ + metadata = self.metadata + name = table.__disco_name__ + sets_list: list[set[int]] = [] + + logger.debug(f"Looking for query {query!r} in {name}") + for field, value in query.items(): + if field not in metadata.keys: + raise DatabaseLookupError( + f"table {metadata.name} has no field {field}" + ) + + channel = self._find_channel( + metadata.index_channels[f"{name}_{field}"] + ) + + hashed_field, target_index = self._as_hashed(value) + entry_message = await self._lookup_message( + channel, + target_index, + ) + + serialized_content = _IndexableRecord.from_message( + entry_message.content + ) + + if not serialized_content: + logger.info("Nothing was found.") + continue + + if serialized_content.key == hashed_field: + logger.debug(f"Key matches hash! {serialized_content}") + sets_list.append(set(serialized_content.record_ids)) + else: + # Hash collision! + def find_hash(message: str | None) -> bool: + if not message: + return False + + index_record = _IndexableRecord.from_message(message) + if not index_record: + return False + + return index_record.key == hashed_field + + entry = await self._find_collision_message( + channel, + target_index, + search_func=find_hash, + ) + + rec = _IndexableRecord.from_message(entry.content) + logger.debug(f"Found hash collision index entry: {rec}") # noqa + if not rec: + # This shouldn't be possible, considering the + # search function explicitly disallows that. 
+ raise DatabaseCorruptionError( + "search function found null entry somehow" + ) + + sets_list.append(set(rec.record_ids)) + + if not query: + logger.info("Query is empty, finding all entries!") + channel = self._find_channel(metadata.table_channel) + async for msg in channel.history(limit=None): + logger.debug(f"Found message in channel: {msg}") + sets_list.append({msg.id}) + + main_table = self._find_channel(metadata.table_channel) + if not isinstance(main_table, discord.TextChannel): + raise DatabaseCorruptionError( + f"expected {main_table!r} to be a TextChannel" + ) + + logger.debug(f"Got IDs: {sets_list}") + records: list[Table] = [] + + for record_ids in sets_list: + for record_id in record_ids: + message = await main_table.fetch_message(record_id) + record = _Record.model_validate_json(message.content) + entry = record.decode_content(table) + entry.__disco_id__ = message.id + records.append(entry) + + return records + + async def _gen_key_channel( + self, + table: str, + key_name: str, + *, + initial_size: int = 4, + ) -> tuple[str, int, int]: + """ + Generate a key channel from the given information. + This does not check if it exists. + + Args: + table: Processed channel name of the table. + key_name: Name of the key, per `__disco_keys__`. + initial_size: Equivalent to `initial_size` in `create_table`. + + Returns: + tuple[int, int, int]: Tuple containing the channel name, + the ID of the created channel, and the snowflake time of the + last message created in the hash table. + """ + # Key channels are stored in + # the format of _ + index_channel = await self.guild.create_text_channel( + f"{table}_{key_name}" + ) + logger.debug(f"Generated key channel: {index_channel}") + last_message_snowflake = await self._resize_hash( + index_channel, initial_size + ) + return index_channel.name, index_channel.id, last_message_snowflake + + @classmethod + async def create_table( + cls, + table: type[Table], + metadata_channel: discord.TextChannel, + guild: discord.Guild, + initial_size: int = 4, + ) -> TableCursor: + """ + Creates a new table and all index tables that go with it. + This writes the table metadata. + + If the table already exists, this method does (almost) nothing. + + Args: + table: Table schema to create channels for. + initial_hash_size: the size the index hash tables should start at. 
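+            metadata_channel: Channel that stores the table metadata messages.
+            guild: Guild (server) in which the table channels are created.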
+ + Returns: + TableCursor: An object used to manage a table + """ + + logger.debug(f"create_table called with table: {table!r}") + name = table.__disco_name__ + existing_metadata: Metadata | None = None + + async for msg in metadata_channel.history(limit=None): + try: + parsed_meta = Metadata.model_validate_json(msg.content) + except ValidationError as e: + raise DatabaseCorruptionError("got invalid metadata") from e + + if parsed_meta.name == name: + logger.debug( + f"Found existing metadata for table {name}: {parsed_meta}" + ) + existing_metadata = parsed_meta + break + + if existing_metadata and ( + set(existing_metadata.keys) != table.__disco_keys__ + ): + logger.error( + f"stored keys: {', '.join(existing_metadata.keys)} -- table keys: {', '.join(table.__disco_keys__)}" # noqa + ) + raise DatabaseCorruptionError(f"schema for table {name} changed") + + matching: list[str] = [] + for channel in guild.channels: + for key in table.__disco_keys__: + if channel.name == f"{name}_{key}": + matching.append(key) + + if existing_metadata and matching: + if not len(matching) == len(table.__disco_keys__): + raise DatabaseCorruptionError( + f"only some key channels exist: {', '.join(matching)}", + ) + + logger.info(f"Table is already set up: {table.__disco_name__}") + cursor = TableCursor(existing_metadata, metadata_channel, guild) + table.__disco_cursor__ = cursor + return cursor + + logger.info(f"Building table: {table.__disco_name__}") + + # The primary table holds the actual records + primary_table = await guild.create_text_channel(name) + logger.debug(f"Generated primary table: {primary_table!r}") + + metadata = Metadata( + name=name, + keys=tuple(table.__disco_keys__), + table_channel=primary_table.id, + index_channels={}, + current_records=0, + max_records=initial_size, + time_table={}, + message_id=0, + ) + self = TableCursor(metadata, metadata_channel, guild) + timestamp_snowflake: int | None = None + + index_channels: dict[str, int] = {} + # This is ugly, but this is fast: we generate + # the key channels in parallel. + for data in await asyncio.gather( + *[ + self._gen_key_channel( + name, + key_name, + initial_size=initial_size, + ) + for key_name in table.__disco_keys__ + ] + ): + channel_name, channel_id, timestamp_snowflake = data + index_channels[channel_name] = channel_id + + assert timestamp_snowflake is not None + metadata.time_table = {timestamp_snowflake: (0, initial_size)} + metadata.index_channels = index_channels + message = await self.metadata_channel.send( + metadata.model_dump_json(), silent=True + ) + + table.__disco_cursor__ = self + # Since Discord generates the message ID, we have to do these + # message editing shenanigans. + metadata.message_id = message.id + await message.edit(content=metadata.model_dump_json()) + logger.debug(f"Generated table metadata: {metadata!r}") + return self + + async def delete_record(self, record: Table) -> None: + """ + Deletes an existing record in a table. + + Args: + record: The record object being deleted from the table. 
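+
+            Every index entry for the record is nullified (or has the record's
+            message ID removed if other records share the value), the message
+            in the main table is deleted, and `__disco_id__` is reset to -1.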
+ """ + if record.__disco_id__ == -1: + # Sanity check + raise DatabaseCorruptionError("record must have an id to update") + + metadata = self.metadata + main_table: discord.TextChannel = self._find_channel( + metadata.table_channel + ) + msg = await main_table.fetch_message(record.__disco_id__) + current = _Record.model_validate_json(msg.content).decode_content( + record + ) + + for field, value in current.model_dump().items(): + channel = self._find_channel( + metadata.index_channels[f"{current.__disco_name__}_{field}"] + ) + + index = self._to_index(self._hash(value)) + index_message = await self._lookup_message(channel, index) + index_record = _IndexableRecord.from_message(index_message.content) + + if not index_record: + raise DatabaseCorruptionError("got null record somehow") + + if len(index_record.record_ids) == 1: + logger.info("We can nullify this entry.") + await index_message.edit(content="null") + self.metadata.current_records -= 1 + else: + logger.info( + "There are other entries with this value, only remove this ID." # noqa + ) + index_record.record_ids.remove(msg.id) + await index_message.edit( + content=index_record.model_dump_json(), + ) + + record.__disco_id__ = -1 + await msg.delete() diff --git a/spunky-sputniks/src/discobase/_metadata.py b/spunky-sputniks/src/discobase/_metadata.py new file mode 100644 index 0000000..ecda8c1 --- /dev/null +++ b/spunky-sputniks/src/discobase/_metadata.py @@ -0,0 +1,24 @@ +from typing import Dict, Tuple + +from pydantic import BaseModel + + +class Metadata(BaseModel): + name: str + """The table name.""" + keys: Tuple[str, ...] + """A tuple containing the name of all keys/fields of the table.""" + table_channel: int + """Channel ID that holds the main table content.""" + index_channels: Dict[str, int] + """A dictionary containing index channel names with index channel IDs.""" + current_records: int + """Number of (used) records in the table.""" + max_records: int + """Capacity of the table (i.e. the "maximum records" that is can hold).""" + time_table: Dict[ + int, Tuple[int, int] + ] # Pydantic doesn't support range objects + """Table of UNIX timestamp -> index range.""" + message_id: int + """ID of the metadata message.""" diff --git a/spunky-sputniks/src/discobase/_util.py b/spunky-sputniks/src/discobase/_util.py new file mode 100644 index 0000000..16e7091 --- /dev/null +++ b/spunky-sputniks/src/discobase/_util.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from typing import Any, Coroutine, TypeVar + +import discord +from loguru import logger + +__all__ = "gather_group", "GatherGroup", "free_fly" + +T = TypeVar("T") + + +class GatherGroup: + def __init__(self) -> None: + self.tasks: list[asyncio.Task] = [] + + def add(self, awaitable: Coroutine[Any, Any, T]) -> asyncio.Task[T]: + async def inner_coro(): + while True: + try: + return await awaitable + except discord.HTTPException as e: + if e.code == 429: + logger.warning("Ratelimited! Retrying...") + await asyncio.sleep(0.1) + + task = asyncio.create_task(inner_coro()) + self.tasks.append(task) + return task + + +# A partial reimplementation of asyncio.TaskGroup, but +# that's only on 3.11+ anyway. 
+@asynccontextmanager +async def gather_group(): + group = GatherGroup() + + try: + yield group + finally: + logger.debug(f"Gathering tasks: {group.tasks}") + await asyncio.gather(*group.tasks) + + +_TASKS = set() + + +def free_fly(coro: Coroutine[Any, Any, T]) -> asyncio.Task[T]: + task = asyncio.create_task(coro) + _TASKS.add(task) + task.add_done_callback(_TASKS.discard) + return task diff --git a/spunky-sputniks/src/discobase/cogs/utility.py b/spunky-sputniks/src/discobase/cogs/utility.py new file mode 100644 index 0000000..6cca479 --- /dev/null +++ b/spunky-sputniks/src/discobase/cogs/utility.py @@ -0,0 +1,256 @@ +import json + +import discord +from discord import app_commands +from discord.ext import commands +from loguru import logger +from pydantic import ValidationError + +from ..ui import embed as em + + +class Utility(commands.Cog): + """ + All the slash commands for querying information from the database. + """ + + def __init__(self, bot: commands.Bot) -> None: + self.bot = bot + self.db = self.bot.db + + @app_commands.command(description="Insert new data into a table.") + @app_commands.describe( + table="Choose the table you want to insert the data into.", + data="The data that is to be inserted.", + ) + async def insert( + self, + interaction: discord.Interaction, + table: discord.TextChannel, + data: str, + ) -> None: + logger.info("Called 'insert' command.") + await interaction.response.send_message( + content=f"Looking for `{table.name}`..." + ) + + table_name = table.name.replace("-", " ").lower() + + try: + table_obj = self.db.tables[table_name] # Table object + except IndexError as e: + logger.error(e) + await interaction.edit_original_response( + content=f"The table `{table_name}` does not exist." + ) + return + + try: + await interaction.edit_original_response( + content=f"Table `{table_name}` found! Adding data to table..." + ) + data_dict: dict = json.loads(data) + except TypeError as e: + logger.error(e) + await interaction.edit_original_response( + content=f"The data you entered was not in json format.\nEntered data: {data}" + ) + return + + try: + logger.info("Adding new data to table") + new_entry = table_obj(**data_dict) + except ValidationError as e: + logger.error(e) + await interaction.edit_original_response( + content=f"You are missing one of the following columns in your data: `{table_obj.__disco_keys__}`." + ) + return + + await new_entry.save() + + await interaction.edit_original_response( + content=f"I have inserted `{data}` into `{table_name}` table." + ) + + @app_commands.command( + description="Finds a record with the specific column and value in the table." + ) + @app_commands.describe( + table="The name of the table the column is in.", + column="The name of the column.", + value="The value to search for.", + ) + async def find( + self, + interaction: discord.Interaction, + table: discord.TextChannel, + column: str, + value: str, + ) -> None: + table_info: list | None = None + results: list | None = None + column: str = column.lower() + results_found: int | None = None + results_str: str = "" + + logger.debug("Find slash cmd initialised.") + await interaction.response.send_message( + content=f"Searching for `{value}`..." + ) + + if table.name in self.bot.db.tables: + table_info = self.bot.db.tables[table.name] + + if column in table_info.__disco_keys__: + results = await table_info.find(**{column: value}) + results_found = len(results) + if results_found > 0: + for count, value in enumerate(results, start=1): + results_str += f"**{count}**. 
{str(value)}\n" + + embed = em.EmbedFromContent( + title=f"Search Result - {results_found} Record(s) Found", + content=[], + headers=None, + style=em.EmbedStyle.DEFAULT, + ).create() + + embed.description = results_str + + await interaction.edit_original_response( + content="", embed=embed + ) + else: + await interaction.edit_original_response( + content="The record could not be found." + ) + else: + await interaction.edit_original_response( + content="Either the table doesn't exist or the column doesn't exist." + ) + + @app_commands.command(description="Modifies a record with a new value.") + @app_commands.describe( + table="Choose the table you want to perform an update on.", + column="Choose the column you want to update.", + current_value="The current value saved in the column.", + new_value="Your new information.", + ) + async def update( + self, + interaction: discord.Interaction, + table: discord.TextChannel, + column: str, + current_value: str, + new_value: str, + ) -> None: + logger.debug("Update slash cmd initialized.") + table_info: list | None = None + if table.name in self.db.tables: + await interaction.response.send_message( + content=f"Table `{table.name}` found! Searching for record..." + ) + table_info = self.db.tables[table.name] + + try: + if column in table_info.__disco_keys__: + column_name = [ + col + for col in table_info.__disco_keys__ + if col.lower() == column.lower() + ][0] + found_table = ( + await table_info.find(**{column_name: current_value}) + )[0] + setattr(found_table, column_name, new_value) + found_table.update() + await interaction.edit_original_response( + content=f"Successfully updated the value of **{column}** in **{table.name}**." + ) + else: + await interaction.edit_original_response( + content="The column does not exist." + ) + except ValidationError: + await interaction.edit_original_response( + content=f"`{new_value}` could not be converted to the field's data type, use `/schema` to " + f"check the data type of the column before trying again." + ) + else: + await interaction.edit_original_response( + content="There is no table with that name, try creating a table." + ) + + @app_commands.command(description="Deletes a record from a table.") + @app_commands.describe( + table="The table from which you want to delete", + record="The record you want to delete - formatted as a json.", + ) + async def delete( + self, + interaction: discord.Interaction, + table: discord.TextChannel, + record: str, + ) -> None: + logger.debug("Delete cmd initialized.") + await interaction.response.send_message( + content=f"Searching for table `{table.name}`..." + ) + + try: + table_obj = self.db.tables[table.name] + await interaction.edit_original_response( + content=f"Table `{table_obj.__disco_name__}` found! Searching for record..." + ) + except IndexError as e: + logger.error(e) + await interaction.edit_original_response( + content=f"The table `{table.name}` does not exist." + ) + return + + try: + record_dict = json.loads(record) + table_record = await table_obj.find(**record_dict) + table_record = table_record[0] + if table_record is None: + raise ValueError + await interaction.edit_original_response( + content=f"Record `{record}` found! Deleting..." + ) + except ValueError as e: + logger.error(e) + await interaction.edit_original_response( + content=f"No record found for `{record}`." 
+ ) + return + except TypeError as e: + logger.error(e) + await interaction.edit_original_response( + content=f"The record you entered was not in json format.\nEntered record: {record}" + ) + return + + await table_record.delete() + + await interaction.edit_original_response( + content=f"Record `{record}` has been deleted from `{table.name}`!" + ) + + @app_commands.command( + description="Resets the database, deleting all channels and unloading tables." + ) + async def reset(self, interaction: discord.Interaction) -> None: + logger.debug("Reset cmd initialized.") + await interaction.response.send_message( + content=f"Resetting the database, `{self.db.name}`..." + ) + await self.db.clean() + await interaction.edit_original_response( + content=f"Database `{self.db.name}` has been reset!" + ) + + +async def setup(bot) -> None: + await bot.add_cog(Utility(bot)) diff --git a/spunky-sputniks/src/discobase/cogs/visualization.py b/spunky-sputniks/src/discobase/cogs/visualization.py new file mode 100644 index 0000000..344aa23 --- /dev/null +++ b/spunky-sputniks/src/discobase/cogs/visualization.py @@ -0,0 +1,215 @@ +import discord +from discord import app_commands +from discord.ext import commands +from loguru import logger + +from ..ui import embed as em + + +class Visualization(commands.Cog): + """ + Slash commands to visualize the database's data. + """ + + def __init__(self, bot: commands.Bot) -> None: + self.bot = bot + self.db = self.bot.db + + @app_commands.command() + async def hello(self, interaction): + await interaction.response.send_message("Hello!") + + @app_commands.command(description="View the selected table.") + @app_commands.describe(name="The name of the table.") + async def table( + self, interaction: discord.Interaction, name: discord.TextChannel + ) -> None: + logger.debug("Table slash cmd initialized.") + await interaction.response.send_message( + content=f"Searching for table `{name}`..." + ) + table_name = name.name.replace("-", " ").lower() + + try: + table = self.db.tables[table_name] + await interaction.edit_original_response( + content=f"Table `{table_name}` found! Gathering data..." + ) + except IndexError as e: + logger.error(e) + await interaction.edit_original_response( + content=f"The table `{name.name}` does not exist." + ) + return + + table_columns = [ + col for col in table.__disco_keys__ + ] # convert set to list to enable subscripting + + data: dict[str:list] = {} + for col in table_columns: + data[col] = [] + + table_values = await table.find() + logger.info(table_values) + + await interaction.edit_original_response( + content="Still gathering data..." + ) + + for game in table_values: + for col in table_columns: + data[col].append(getattr(game, col)) + + embed_from_content = em.EmbedFromContent( + title=f"Table: {table.__disco_name__.title()}", + content=data, + headers=table_columns, + style=em.EmbedStyle.TABLE, + ) + embeds = embed_from_content.create() + + view = em.ArrowButtons(content=embeds) + + await interaction.edit_original_response( + content="", embed=embeds[0], view=view + ) + + @app_commands.command(description="View the column data.") + @app_commands.describe( + table="The name of the table the column is in.", + name="The name of the column.", + ) + async def column( + self, + interaction: discord.Interaction, + table: discord.TextChannel, + name: str, + ) -> None: + logger.debug("Column slash cmd initialized.") + await interaction.response.send_message( + f"Searching for table `{table.name}`..." 
+ ) + try: + col_table = self.db.tables[table.name] + await interaction.edit_original_response( + content=f"Table `{col_table.__disco_name__}` found! Gathering column data..." + ) + except IndexError as e: + logger.error(e) + await interaction.edit_original_response( + content=f"The table `{table.name}` does not exist." + ) + return + + try: + column = [ + col + for col in col_table.__disco_keys__ + if col.lower() == name.lower() + ][0] + except (IndexError, ValueError) as e: + logger.error(e) + await interaction.edit_original_response( + content=f"The column `{name}` does not exist in the table `{col_table.__disco_name__}`." + ) + return + + table_records = await col_table.find() + column_values = [getattr(record, column) for record in table_records] + + embeds = em.EmbedFromContent( + title=f"Column `{name.title()}` from Table `{col_table.__disco_name__.title()}`", + content=column_values, + headers=None, + style=em.EmbedStyle.COLUMN, + ).create() + + view = em.ArrowButtons(content=embeds) + + await interaction.edit_original_response( + content="", embed=embeds[0], view=view + ) + + @app_commands.command( + description="Displays the number of tables and the names of the tables." + ) + async def tablestats(self, interaction: discord.Interaction) -> None: + logger.debug("Tablestats slash cmd initialized.") + await interaction.response.send_message( + content="Getting table statistics..." + ) + try: + tables_names: list | None = None + tables_names = [table for table in self.db.tables] + logger.debug(tables_names) + combined_tables_names = "\n".join(tables_names) + + embed_gen = em.EmbedFromContent( + title="Tables", + content=[], + headers=None, + style=em.EmbedStyle.DEFAULT, + ).create() + + embed_gen.add_field( + name="Number of Tables", + value=len(self.db.tables), + ) + + embed_gen.add_field( + name="Names of Tables", + value=combined_tables_names, + ) + + await interaction.edit_original_response( + content="", embed=embed_gen + ) + except Exception as e: + logger.exception(e) + return + + @app_commands.command( + description="Retrieves and displays the schema for the table.", + ) + @app_commands.describe( + table="Choose the table you want to retrieve the schema from.", + ) + async def schema( + self, interaction: discord.Interaction, table: discord.TextChannel + ) -> None: + logger.debug("Schema slash cmd initialized.") + await interaction.response.send_message( + content=f"Getting schema for {table.name}..." + ) + table_info: list | None = None + table_schema: dict | None = None + schemas: list[dict] | None = None + embed_gen: discord.Embed | None = None + + if table.name in self.db.tables: + table_info = self.db.tables[table.name] + table_schema = table_info.model_json_schema() + schemas = [ + table_schema["properties"][disco_key] + for disco_key in table_info.__disco_keys__ + ] + + embed_gen = em.EmbedFromContent( + title=f"Table: {table.name.title()}", + content=schemas, + headers=None, + style=em.EmbedStyle.SCHEMA, + ).create() + + await interaction.edit_original_response( + content="", embed=embed_gen + ) + else: + await interaction.edit_original_response( + content="There is no table with that name, try creating a table." 
+ ) + + +async def setup(bot: commands.Bot) -> None: + await bot.add_cog(Visualization(bot)) diff --git a/spunky-sputniks/src/discobase/database.py b/spunky-sputniks/src/discobase/database.py new file mode 100644 index 0000000..11f552b --- /dev/null +++ b/spunky-sputniks/src/discobase/database.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from typing import Coroutine, NoReturn, Type, TypeVar + +import discord +from discord.ext import commands +from loguru import logger + +from ._cursor import TableCursor +from .cogs.utility import Utility +from .cogs.visualization import Visualization +from .exceptions import (DatabaseCorruptionError, DatabaseTableError, + NotConnectedError) +from .table import Table + +__all__ = ("Database",) + +T = TypeVar("T", bound=Type[Table]) + + +class Database: + """ + Top level class representing a Discord + database bot controller. + """ + + def __init__( + self, + name: str, + logging: bool = False, + ) -> None: + """ + Args: + name: Name of the Discord server that will be used as the database. + logging: Whether to enable logging. + """ + if logging: + logger.enable("discobase") + else: + logger.disable("discobase") + + self.name = name + """Name of the Discord-database server.""" + self.bot = commands.Bot( + intents=discord.Intents.all(), + command_prefix="!", + ) + """discord.py `Bot` instance.""" + self.guild: discord.Guild | None = None + """discord.py `Guild` used as the database server.""" + self.tables: dict[str, type[Table]] = {} + """Dictionary of `Table` objects attached to this database.""" + self.open: bool = False + """Whether the database is connected.""" + self._metadata_channel: discord.TextChannel | None = None + """discord.py `TextChannel` that acts as the metadata channel.""" + self._database_cursors: dict[str, TableCursor] = {} + """A dictionary containing all of the table `Metadata` entries""" + self._task: asyncio.Task[None] | None = None + self.bot.db = self # type: ignore + # We need to keep a strong reference to the free-flying + # task + self._setup_event = asyncio.Event() + self._internal_setup_event = asyncio.Event() + self._on_ready_exc: BaseException | None = None + + # Here be dragons: https://github.com/ZeroIntensity/discobase/issues/49 + # + # `on_ready` in discord.py swallows all exceptions, which + # goes against some connect-and-disconnect behavior + # that we want to allow in discobase. + # + # We need to store the exception, and then raise in wait_ready() + # in order to properly handle it, otherwise the discord.py + # logger just swallows it and pretends nothing happened. + # + # This also caused a deadlock with _setup_event, which caused + # CI to run indefinitely. + @self.bot.event + @logger.catch(reraise=True) + async def on_ready() -> None: + try: + await self.bot.add_cog(Utility(self.bot)) + await self.bot.add_cog(Visualization(self.bot)) + await self.init() + except BaseException as e: + await self.bot.close() + if self._task: + self._task.cancel("bot startup failed") + + self._setup_event.set() + self._on_ready_exc = e + raise # This is swallowed! + + def _not_connected(self) -> NoReturn: + """ + Complain about the database not being connected. + + Generally speaking, this is called when `guild` or something + other is `None`. + """ + + raise NotConnectedError( + "not connected to the database, did you forget to login?" + ) + + async def _metadata_init(self) -> discord.TextChannel: + """ + Find the metadata channel. 
+ If it doesn't exist, this method creates one. + + Returns: + discord.TextChannel: The metadata channel, either created or found. + """ + metadata_channel_name = "_dbmetadata" + found_channel: discord.TextChannel | None = None + assert self.guild is not None + + for channel in self.guild.text_channels: + if channel.name == metadata_channel_name: + found_channel = channel + logger.info("Found metadata channel!") + break + + return found_channel or await self.guild.create_text_channel( + name=metadata_channel_name + ) + + # This needs to be async for use in gather() + async def _set_open(self) -> None: + logger.debug("_set_open waiting on internal setup event") + await self._internal_setup_event.wait() + logger.debug( + "Internal setup event dispatched! Database has been marked as open." # noqa + ) + self.open = True + # See https://github.com/ZeroIntensity/discobase/issues/68 + # + # If `wait_ready()` is never called, then the error does not propagate. + if self._on_ready_exc: + raise self._on_ready_exc + + async def init(self) -> None: + """ + Initializes the database server. + + Generally, you don't want to call this manually, but + this is considered to be a public interface. + """ + logger.info("Waiting until bot is logged in.") + await self.bot.wait_until_ready() + logger.info("Bot is ready!") + found_guild: discord.Guild | None = None + for guild in self.bot.guilds: + if guild.name == self.name: + found_guild = guild + break + + if not found_guild: + logger.info("No guild found, making one.") + self.guild = await self.bot.create_guild(name=self.name) + else: + logger.info("Found an existing guild.") + self.guild = found_guild + + # Unlock database, but don't wakeup the user. + self._internal_setup_event.set() + await self.build_tables() + # At this point, the database is marked as "ready" to the user. + self._setup_event.set() + + assert self._metadata_channel is not None + logger.info( + f"Invite to server: {await self._metadata_channel.create_invite()}" + ) + logger.info("Syncing slash commands, this might take a minute.") + logger.debug(f"Synced slash commands: {await self.bot.tree.sync()}") + + async def build_tables(self) -> None: + """ + Generate all internal metadata and construct tables. + + Calling this manually is useful if e.g. you want + to load tables *after* calling `login` (or `login_task`, or + `conn`, same difference.) + + This method is safe to call multiple times. + + Example: + ```py + import asyncio + import discobase + + async def main(): + db = discobase.Database("My database") + db.login_task("My bot token") + + @db.table + class MyLateTable(discobase.Table): + something: str + + # Load MyLateTable into database + await db.build_tables() + + asyncio.run(main()) + ``` + """ + if not self.guild: + self._not_connected() + + self._metadata_channel = await self._metadata_init() + tasks = [ + asyncio.ensure_future( + TableCursor.create_table( + table, + self._metadata_channel, + self.guild, + ) + ) + for table in self.tables.values() + ] + logger.debug(f"Creating tables with gather(): {tasks}") + await asyncio.gather(*tasks) + + async def wait_ready(self) -> None: + """ + Wait until the database is ready. + """ + logger.info("Waiting until the database is ready.") + await self._setup_event.wait() + logger.info("Done waiting!") + # See #49, we need to propagate errors in `on_ready` here. 
+ if self._on_ready_exc: + logger.error("on_ready() failed, propagating now.") + raise self._on_ready_exc + + def _find_channel(self, cid: int) -> discord.TextChannel: + # TODO: Implement caching for this function. + if not self.guild: + self._not_connected() + + index_channel = [ + channel for channel in self.guild.channels if channel.id == cid + ][0] + + if not isinstance(index_channel, discord.TextChannel): + raise DatabaseCorruptionError( + f"expected {index_channel!r} to be a TextChannel" + ) + + logger.debug(f"Found channel ID {cid}: {index_channel!r}") + return index_channel + + async def clean(self) -> None: + """ + Perform a full clean of the database. + + This method erases all metadata, all entries, and all tables. After + calling this, a server loses any trace of the database ever being + there. + + Note that this does *not* clean the existing tables from memory, but + instead just marks them all as not ready. + + This action is irreversible. + """ + if not self._metadata_channel: + self._not_connected() + + logger.info("Cleaning the database!") + + coros: list[Coroutine] = [] + for table, cursor in self._database_cursors.items(): + metadata = cursor.metadata + logger.info(f"Cleaning table {table}") + table_channel = self._find_channel(metadata.table_channel) + coros.append(table_channel.delete()) + + for cid in metadata.index_channels.values(): + channel = self._find_channel(cid) + coros.append(channel.delete()) + + for schema in self.tables.values(): + schema.__disco_cursor__ = None + + logger.debug(f"Gathering deletion coros: {coros}") + await asyncio.gather(*coros) + logger.info("Deleting database metadata.") + self._database_cursors = {} + await self._metadata_channel.delete() + + async def login(self, bot_token: str) -> None: + """ + Start running the bot. + + Args: + bot_token: Discord API bot token to log in with. + """ + if self.open: + raise RuntimeError( + "connection is already open, did you call login() twice?" + ) + + # We use _set_open() with a gather to keep a finer link + # between the `open` attribute and whether it's actually + # running. + await asyncio.gather(self.bot.start(token=bot_token), self._set_open()) + + def login_task(self, bot_token: str) -> asyncio.Task[None]: + """ + Call `login()` as a free-flying task, instead of + blocking the event loop. + + Note that this method stores a reference to the created + task object, allowing it to be truly "free-flying." + + Args: + bot_token: Discord API bot token to log in to. + + Returns: + asyncio.Task[None]: The created `asyncio.Task` object. + Note that the database will store this internally, so you + don't have to worry about losing the reference. By default, + this task will never get `await`ed, so this function will not + keep the event loop running. If you want to keep the event loop + running, make sure to `await` the returned task object later. + + Example: + ```py + import asyncio + import os + + import discobase + + + async def main(): + db = discobase.Database("test") + task = await db.login_task("...") + await db.wait_ready() + # ... + await task # Keep the event loop running + + asyncio.run(main()) + ``` + """ + task = asyncio.create_task(self.login(bot_token)) + self._task = task + return task + + async def close(self) -> None: + """ + Close the bot connection. + """ + if not self.open: + # If something went wrong in startup, for example, then + # we need to release the setup lock. 
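+            # Otherwise a caller blocked in wait_ready() could hang forever.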
+ self._setup_event.set() + raise ValueError( + "cannot close a connection that is not open", + ) + self.open = False + await self.bot.close() + + @asynccontextmanager + async def conn(self, bot_token: str): + """ + Connect to the bot under a context manager. + This is the recommended method to use for logging in. + + Args: + bot_token: Discord API bot token to log in to. + + Returns: + AsyncGeneratorContextManager: An asynchronous context manager. + See `contextlib.asynccontextmanager` for details. + + Example: + ```py + import asyncio + import os + + import discobase + + + async def main(): + db = discobase.Database("test") + async with db.conn(os.getenv("BOT_TOKEN")): + ... # Your database code + + + asyncio.run(main()) + ``` + """ + try: + self.login_task(bot_token) + await self.wait_ready() + yield + finally: + if self.open: # Something could have gone wrong + await self.close() + + def table(self, clas: T) -> T: + """ + Mark a `Table` type as part of a database. + This method is meant to be used as a decorator. + + Args: + clas: Type object to attach. + + Example: + ```py + import discobase + + db = discobase.Database("My database") + + @db.table + class MySchema(discobase.Table): + foo: int + bar: str + + # ... + ``` + + Returns: + Type[Table]: The same object passed to `clas` -- this is in order + to allow use as a decorator. + """ + if not issubclass(clas, Table): + raise DatabaseTableError( + f"{clas} is not a subclass of discobase.Table, did you forget it?", # noqa + ) + + clas.__disco_name__ = clas.__name__.lower() + if clas.__disco_name__ in self.tables: + raise DatabaseTableError(f"table {clas.__name__} already exists") + + if clas.__disco_database__ is not None: + raise DatabaseTableError( + f"{clas!r} can only be attached to one database" + ) + + clas.__disco_database__ = self + + # This is up for criticism -- instead of using Pydantic's + # `model_fields` attribute, we invent our own `__disco_keys__` instead. + # + # Partially, this is due to the fact that we want `__disco_keys__` to + # be, more or less, stable throughout the codebase. + # + # However, I don't think Pydantic would mess with `model_fields`, as + # that's a public API, and hence why this could possibly be + # considered as bad design. + for field in clas.model_fields: + clas.__disco_keys__.add(field) + + self.tables[clas.__disco_name__] = clas + return clas diff --git a/spunky-sputniks/src/discobase/exceptions.py b/spunky-sputniks/src/discobase/exceptions.py new file mode 100644 index 0000000..1d52cd5 --- /dev/null +++ b/spunky-sputniks/src/discobase/exceptions.py @@ -0,0 +1,34 @@ +class DiscobaseError(Exception): + """ + Base discobase exception class. + """ + + +class NotConnectedError(DiscobaseError): + """ + The database is not connected. + """ + + +class DatabaseCorruptionError(DiscobaseError): + """ + The database was corrupted somehow. + """ + + +class DatabaseStorageError(DiscobaseError): + """ + Failed store something in the database. + """ + + +class DatabaseTableError(DiscobaseError): + """ + Something is wrong with a `Table` type. + """ + + +class DatabaseLookupError(DiscobaseError): + """ + Something went wrong with an entry lookup. 
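+
+    For example, `find_unique` raises this when a strict query matches
+    no records, or matches more than one.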
+ """ diff --git a/spunky-sputniks/src/discobase/py.typed b/spunky-sputniks/src/discobase/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/spunky-sputniks/src/discobase/table.py b/spunky-sputniks/src/discobase/table.py new file mode 100644 index 0000000..3318494 --- /dev/null +++ b/spunky-sputniks/src/discobase/table.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +import asyncio +from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Set, + overload) + +import discord +from pydantic import BaseModel, ConfigDict +from typing_extensions import Self, Unpack + +from ._util import free_fly + +if TYPE_CHECKING: + from .database import Database + +from ._cursor import TableCursor +from .exceptions import (DatabaseLookupError, DatabaseStorageError, + DatabaseTableError, NotConnectedError) + +__all__ = ("Table",) + + +# Note that we can't use 3.10+ type[] syntax +# here, since Pydantic can't handle it +class Table(BaseModel): + __disco_database__: ClassVar[Optional[Database]] + """Attached `Database` object. Set by the `table()` decorator.""" + __disco_cursor__: ClassVar[Optional[TableCursor]] + """Internal table cursor, set at initialization time.""" + __disco_keys__: ClassVar[Set[str]] + """All keys of the table, this may not change once set by `table()`.""" + __disco_name__: ClassVar[str] + """Internal name of the table. Set by the `table()` decorator.""" + __disco_id__: int = -1 + """Message ID of the record. This is only present if it was saved.""" + + def __init__(self, /, **data: Any) -> None: + super().__init__(**data) + self.__disco_id__ = -1 + + def __init_subclass__(cls, **kwargs: Unpack[ConfigDict]) -> None: + super().__init_subclass__(**kwargs) + cls.__disco_database__ = None + cls.__disco_cursor__ = None + cls.__disco_keys__ = set() + cls.__disco_name__ = "_notset" + + @classmethod + def _ensure_db(cls) -> None: + if not cls.__disco_database__: + raise DatabaseTableError( + f"{cls.__name__} has no attached database, did you forget to call @db.table?" # noqa + ) + + if not cls.__disco_database__.open: + raise NotConnectedError( + "database is not connected! did you forget to open it?" + ) + + if not cls.__disco_cursor__: + raise DatabaseTableError( + f"{cls.__name__} is not ready, you might want to add a call to build_tables()", # noqa + ) + + def save(self) -> asyncio.Task[discord.Message]: + """ + Save the entry to the database as a new record. + + Example: + ```py + import discobase + + db = discobase.Database("My database") + + @db.table + class User(discobase.Table): + name: str + password: str + + # Using top-level await for this example + await User(name="Peter", password="foobar").save() + ``` + """ + self._ensure_db() + assert self.__disco_cursor__ + + if self.__disco_id__ != -1: + raise DatabaseStorageError( + "this entry has already been written, did you mean to call update()?", # noqa + ) + task = free_fly(self.__disco_cursor__.add_record(self)) + + def _cb(fut: asyncio.Task[discord.Message]) -> None: + msg = fut.result() + self.__disco_id__ = msg.id + + task.add_done_callback(_cb) + return task + + def _ensure_written(self) -> None: + if self.__disco_id__ == -1: + raise DatabaseStorageError( + "this entry has not been written, did you mean to call save()?", # noqa + ) + + def update(self) -> asyncio.Task[discord.Message]: + """ + Update the entry in-place. 
+ + Example: + ```py + import discobase + + db = discobase.Database("My database") + + @db.table + class User(discobase.Table): + name: str + password: str + + # Using top-level await for this example + user = await User.find_unique(name="Peter", password="foobar") + user.password = str(hash(password)) + await user.update() + ``` + """ + + self._ensure_db() + self._ensure_written() + assert self.__disco_cursor__ + if self.__disco_id__ == -1: + raise DatabaseStorageError( + "this entry has not been written, did you mean to call save()?", # noqa + ) + return free_fly(self.__disco_cursor__.update_record(self)) + + def commit(self) -> asyncio.Task[discord.Message]: + """ + Save the current entry, or update it if it already exists in the + database. + """ + if self.__disco_id__ == -1: + return self.save() + else: + return self.update() + + def delete(self) -> asyncio.Task[None]: + """ + Delete the current entry from the database. + """ + + self._ensure_written() + assert self.__disco_cursor__ + return free_fly(self.__disco_cursor__.delete_record(self)) + + @classmethod + async def find(cls, **kwargs: Any) -> list[Self]: + """ + Find a list of entries in the database. + + Args: + **kwargs: Values to search for. These should be keys in the schema. + + Returns: + list[Table]: The list of objects that match the values in `kwargs`. + + Example: + ```py + import discobase + db = discobase.Database("My database") + @db.table + class User(discobase.Table): + name: str + password: str + # Using top-level await for this example + await User.find(password="foobar").save() + ``` + """ + cls._ensure_db() + assert cls.__disco_cursor__ + return await cls.__disco_cursor__.find_records( + cls, + kwargs, + ) + + @classmethod + @overload + async def find_unique( + cls, + *, + strict: Literal[True] = True, + **kwargs: Any, + ) -> Self: ... + + @classmethod + @overload + async def find_unique( + cls, + *, + strict: Literal[False] = False, + **kwargs: Any, + ) -> Self | None: ... + + @classmethod + async def find_unique( + cls, + *, + strict: bool = True, + **kwargs: Any, + ) -> Self | None: + """ + Find a unique entry in the database. + + Args: + **kwargs: Values to search for. These should be keys in the schema. + + Returns: + Table: Returns a single object that matches the values in`kwargs`. + None: Nothing was found, and `strict` is `False`. + """ + + if not kwargs: + raise ValueError("a query must be passed to find_unique") + + values: list[Self] = await cls.find(**kwargs) + + if not len(values): + if strict: + raise DatabaseLookupError( + f"no entry found with query {kwargs}", + ) + + return None + + if strict and (1 < len(values)): + raise DatabaseLookupError( + "more than one entry was found with find_unique" + ) + + return values[0] diff --git a/spunky-sputniks/src/discobase/ui/embed.py b/spunky-sputniks/src/discobase/ui/embed.py new file mode 100644 index 0000000..42d4748 --- /dev/null +++ b/spunky-sputniks/src/discobase/ui/embed.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +from datetime import datetime as dt +from enum import Enum, auto +from math import ceil + +import discord +from loguru import logger + +""" +How to Use: +1. Use EmbedfromContent for your database outputs with .create() and assign to a variable. + e.g. EmbedfromContent(title="something", content={"foo":"bar"}, headers=["foo"], style="TABLE").create() +2. Use Arrow buttons class with EmbedfromContent output as content argument and assign to variable. +3. 
Input the ArrowButton class as the view, and the embeds as the content in interaction.send_message. +""" + +__all__ = ["ArrowButtons", "EmbedFromContent", "EmbedStyle"] + + +class ArrowButtons(discord.ui.View): + def __init__(self, content: list[discord.Embed]) -> None: + super().__init__(timeout=None) + self.value = None + self.content = content + self.position = 0 + self.pages = len(self.content) + logger.debug(f"pages in button {self.pages}") + self.on_ready() + + @discord.ui.button( + label="◀", style=discord.ButtonStyle.primary, custom_id="l_button" + ) + async def back( + self, interaction: discord.Interaction, button: discord.ui.Button + ) -> None: + """Controls the left button on the qotd list embed""" + # move back a position in the embed list + self.position -= 1 + + # check if we're on the first page, then disable the button to go left if we are (cant go anymore left) + if self.position == 0: + button.disabled = True + + # set the right button to a variable + right_button = [x for x in self.children if x.custom_id == "r_button"][ + 0 + ] + + # check if we're not on the last page, if yes then enable right button + if not self.position == self.pages - 1: + right_button.disabled = False + + # update discord message + await interaction.response.edit_message( + embed=self.content[self.position], view=self + ) + + @discord.ui.button( + label="▶", style=discord.ButtonStyle.primary, custom_id="r_button" + ) + async def forward( + self, interaction: discord.Interaction, button: discord.ui.Button + ) -> None: + """Controls the right button on the qotd list embed""" + # move forward a position in the embed list + self.position += 1 + + # set a variable for left button + left_button = [x for x in self.children if x.custom_id == "l_button"][0] + # check if we're not on the first page, if yes then enable left button + if not self.position == 0: + left_button.disabled = False + + # check if we're on the last page, if yes then disable right button + if self.position == self.pages - 1: + button.disabled = True + + # update discord message + await interaction.response.edit_message( + embed=self.content[self.position], view=self + ) + + def on_ready(self) -> None: + """Checks the number of pages to decide which buttons to have enabled/disabled""" + left_button = [x for x in self.children if x.custom_id == "l_button"][0] + right_button = [x for x in self.children if x.custom_id == "r_button"][ + 0 + ] + + # if we only have one page, disable both buttons + if self.pages == 1: + left_button.disabled = True + right_button.disabled = True + # if we have more than one page, only disable the left button for the first page + else: + left_button.disabled = True + + +class EmbedStyle(str, Enum): + COLUMN = auto() + TABLE = auto() + SCHEMA = auto() + DEFAULT = auto() + + +# TODO add support for character limits: https://anidiots.guide/.gitbook/assets/first-bot-embed-example.png +class EmbedFromContent: + """Creates a list of embeds suited for pagination from inserted content.""" + + def __init__( + self, + title: str, + content: list[str] | dict | list[dict], + style: "EmbedStyle", + headers: list[str] | None = None, + ) -> None: + """ + Sets the base parameters for the embeds. + + :param title: Title of the embed. + :param headers: Columns of the table, will be used for field names. Required if seeking table display. + :param content: Content of the table. + """ + self.author = "Discobase" + self.color = discord.Colour.blurple() + self.title = title if len(title) < 256 else f"{title[0:253]}..." 
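+        # Discord caps embed titles at 256 characters, so longer titles are
+        # truncated with a trailing ellipsis.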
+        self.headers = headers
+        self.content = content
+        self.page_number = 0
+        self.page_total = 0
+        self.url = "https://github.com/ZeroIntensity/discobase"
+        self.icon_url = "https://i.imgur.com/2QH3tEQ.png"
+
+        self.style = style
+
+    def create(self) -> list[discord.Embed] | discord.Embed:
+        if self.style == EmbedStyle.COLUMN:
+            return self._column_display()
+        elif self.style == EmbedStyle.TABLE:
+            return self._table_display()
+        elif self.style == EmbedStyle.SCHEMA:
+            return self._schema_display()
+        elif self.style == EmbedStyle.DEFAULT:
+            return self._default_display()
+        else:
+            raise ValueError("Invalid style input.")
+
+    def _column_display(self) -> list[discord.Embed]:
+        """
+        Creates a list of discord embeds for the column content, 15 rows per embed.
+        """
+        entries_per_page = 15
+        embeds: list[discord.Embed] = []
+
+        column_data: list[str] = self.content
+        self.page_total = ceil(len(column_data) / entries_per_page)
+        logger.debug(f"{self.page_total}, round: {len(column_data) / entries_per_page}")
+
+        # Create each embed with the data.
+        for i in range(0, len(column_data), entries_per_page):
+            self.page_number += 1
+            embed_content = "\n".join(column_data[i : i + entries_per_page])
+            discord_embed = discord.Embed(
+                color=self.color,
+                title=self.title,
+                type="rich",
+                description=embed_content,
+                timestamp=dt.now(),
+            )
+            discord_embed.set_author(
+                name=self.author, url=self.url, icon_url=self.icon_url
+            )
+            discord_embed.set_footer(
+                text=f"Page: {self.page_number}/{self.page_total}"
+            )
+
+            embeds.append(discord_embed)
+
+        return embeds
+
+    def _table_display(self) -> list[discord.Embed]:
+        """
+        Creates a list of discord embeds that display the data in a table, 10 entries per page.
+        """
+        entries_per_page = 10
+        embeds: list[discord.Embed] = []
+
+        column_names: list = self.headers
+        table_data: dict = self.content
+        self.page_total = ceil(
+            len(self.content[self.headers[0]]) / entries_per_page
+        )
+
+        # Paginate based on the length of the first column's data.
+        for i in range(0, len(table_data[column_names[0]]), entries_per_page):
+            self.page_number += 1
+            discord_embed = discord.Embed(
+                color=self.color,
+                title=self.title,
+                type="rich",
+                timestamp=dt.now(),
+            )
+            discord_embed.set_author(
+                name=self.author, url=self.url, icon_url=self.icon_url
+            )
+            discord_embed.set_footer(
+                text=f"Page: {self.page_number}/{self.page_total}"
+            )
+            # Create a field for each column with up to 10 entries.
+            for k, v in table_data.items():
+                field_title = k.title()
+                field_content = "\n".join(
+                    [
+                        f"**{offset + 1}.** {value}"
+                        for offset, value in enumerate(v[i : i + entries_per_page])
+                    ]
+                )
+                discord_embed.add_field(
+                    name=field_title, value=field_content, inline=True
+                )
+            embeds.append(discord_embed)
+
+        return embeds
+
+    def _schema_display(self) -> discord.Embed:
+        """
+        Creates an embed with the schema information: column names as field titles, and types as field values.
+        """
+        embed = discord.Embed(
+            title=self.title, color=self.color, type="rich", timestamp=dt.now()
+        )
+        embed.set_author(name=self.author, url=self.url, icon_url=self.icon_url)
+
+        for content in self.content:
+            embed.add_field(
+                name=content["title"], value=content["type"], inline=True
+            )
+
+        return embed
+
+    def _default_display(self) -> discord.Embed:
+        """
+        Creates an embed with a default visual style.
+        """
+        embed = discord.Embed(
+            title=self.title, color=self.color, type="rich", timestamp=dt.now()
+        )
+        embed.set_author(name=self.author, url=self.url, icon_url=self.icon_url)
+
+        return embed
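To make the three steps in the module docstring concrete, here is a sketch of a hypothetical discord.py interaction handler that builds a table-style embed list and pages through it with `ArrowButtons`; the command name, sample data, and import path are assumptions for illustration, not part of the shipped API.

```py
# Hypothetical slash-command handler following the three documented steps:
# build the embeds, wrap them in ArrowButtons, then send both together.
import discord

from discobase.ui.embed import ArrowButtons, EmbedFromContent, EmbedStyle


async def show_users(interaction: discord.Interaction) -> None:
    # One list per column; the keys become the embed field names.
    content = {
        "name": ["Peter", "Paula"],
        "password": ["foobar", "hunter2"],
    }
    embeds = EmbedFromContent(
        title="User",
        content=content,
        headers=["name", "password"],
        style=EmbedStyle.TABLE,
    ).create()
    view = ArrowButtons(content=embeds)
    # Send the first page; the view's buttons edit the message to page through the rest.
    await interaction.response.send_message(embed=embeds[0], view=view)
```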
+ """ + embed = discord.Embed( + title=self.title, color=self.color, type="rich", timestamp=dt.now() + ) + embed.set_author(name=self.author, url=self.url, icon_url=self.icon_url) + + return embed diff --git a/spunky-sputniks/tests/test_database.py b/spunky-sputniks/tests/test_database.py new file mode 100644 index 0000000..573821d --- /dev/null +++ b/spunky-sputniks/tests/test_database.py @@ -0,0 +1,196 @@ +import os +import random +import string +import sys + +import discord +import pytest +import pytest_asyncio +from pydantic import Field + +import discobase +from discobase.exceptions import DatabaseTableError + + +@pytest_asyncio.fixture(scope="session") +async def database(): + db = discobase.Database("discobase test", logging=True) + db.login_task(os.environ["TEST_BOT_TOKEN"]) + await db.wait_ready() + if db.guild: + await db.guild.delete() + db.guild = None + await db.init() + + try: + yield db + finally: + await db.close() + + +@pytest_asyncio.fixture(scope="session") +def bot(database: discobase.Database): + return database.bot + + +def test_about(): + assert isinstance(discobase.__version__, str) + assert discobase.__license__ == "MIT" + + +@pytest.mark.asyncio(scope="session") +async def test_creation(database: discobase.Database, bot: discord.Client): + found_guild: discord.Guild | None = None + for guild in bot.guilds: + if guild.name == database.name: + found_guild = guild + + assert found_guild == database.guild + + +@pytest.mark.asyncio(scope="session") +async def test_metadata_channel(database: discobase.Database): + assert database._metadata_channel is not None + assert database._metadata_channel.name == "_dbmetadata" + assert database.guild is not None + found: bool = False + + for channel in database.guild.channels: + if channel == database._metadata_channel: + found = True + break + + assert found is True + + +@pytest.mark.asyncio(scope="session") +async def test_schemas(database: discobase.Database): + class Bar(discobase.Table): + name: str + password: str + + with pytest.raises(DatabaseTableError): + # No database attached + await Bar(name="Peter", password="foobar").save() + + with pytest.raises(DatabaseTableError): + # Missing `Table` subclass + @database.table # type: ignore + class Foo: + name: str + password: str + + Bar = database.table(Bar) + with pytest.raises(DatabaseTableError): + Bar = database.table(Bar) + # Duplicate table name + with pytest.raises(DatabaseTableError): + # Not ready + await Bar(name="Peter", password="foobar").save() + + await database.build_tables() + user = Bar(name="Peter", password="foobar") + await user.save() + assert (await Bar.find_unique(name="Peter")) == user + + +@pytest.mark.asyncio(scope="session") +async def test_resizing(database: discobase.Database): + @database.table + class User(discobase.Table): + name: str + password: str + + await database.build_tables() + + things: list[str] = [ + "aa", + "bbbbbb", + f"cc{random.randint(100, 10000)}", + f"{random.randint(1000, 100000)}dd{random.randint(10000, 100000)}", + "".join( + random.choices( + string.ascii_letters, + k=random.randint(10, 40), + ) + ), + ] + + for name in things: + await User(name=name, password="test").save() + + items = await User.find(password="test") + assert len(items) == len(things) + for i in items: + assert i.name in things + + +@pytest.mark.skipif( + sys.version_info[1] != 12, + reason="Very long, only run on 3.12", +) +@pytest.mark.asyncio(scope="session") +async def test_long_resize(database: discobase.Database): + @database.table + class 
+    class X(discobase.Table):
+        foo: str
+        bar: str = Field(default="bar")
+
+    await database.build_tables()
+
+    for char in string.ascii_letters:
+        await X(foo=char).save()
+
+    items = await X.find()
+    assert len(items) == len(string.ascii_letters)
+    for i in items:
+        assert i.bar == "bar"
+        assert i.foo in string.ascii_letters
+
+
+# async def test_clean(database: discobase.Database):
+#     await database.clean()
+#
+#     with pytest.raises(DatabaseTableError):
+#
+#         @database.table
+#         class User(discobase.Table):
+#             test: str
+#
+#     @database.table
+#     class Whatever(discobase.Table):
+#         foo: str
+#
+#     await Whatever(foo="bar").save()
+#     await database.clean()
+#
+#     assert len(await Whatever.find()) == 0
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_update(database: discobase.Database):
+    @database.table
+    class UpdateTest(discobase.Table):
+        foo: str
+
+    await database.build_tables()
+    test = UpdateTest(foo="test")
+    await test.save()
+    test.foo = "test again"
+    await test.update()
+    assert len(await UpdateTest.find(foo="test")) == 0
+    assert len(await UpdateTest.find(foo="test again")) == 1
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_delete(database: discobase.Database):
+    @database.table
+    class DeleteTest(discobase.Table):
+        foo: str
+
+    await database.build_tables()
+    test = DeleteTest(foo="test")
+    await test.save()
+    assert len(await DeleteTest.find()) == 1
+    await test.delete()
+    assert len(await DeleteTest.find()) == 0
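The session-scoped `database` fixture above needs a real bot token in `TEST_BOT_TOKEN` before it can log in, so the suite cannot run without one. Below is a small local-runner sketch, assuming the token lives in an environment variable of your choosing; the `DISCOBASE_TEST_TOKEN` name and test path are illustrative only.

```py
# Local runner sketch: map your own secret onto the TEST_BOT_TOKEN variable
# that the `database` fixture reads, then hand control to pytest.
import os
import sys

import pytest

os.environ.setdefault("TEST_BOT_TOKEN", os.environ.get("DISCOBASE_TEST_TOKEN", ""))

sys.exit(pytest.main(["tests/test_database.py", "-x", "-q"]))
```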