Skip to content

Commit 24b0a1d

Browse files
authored
feat(migration): improved template customization (#245)
Improved python migration file support/templating. Allow all fields to be configurable (such as `title` and `author` of migration)
1 parent 2cb5dfe commit 24b0a1d

File tree

12 files changed

+1042
-260
lines changed

12 files changed

+1042
-260
lines changed

docs/usage/migrations.rst

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,81 @@ All database configs (sync and async) provide these migration methods:
139139
``get_current_migration(verbose=False)``
140140
Get the current migration version.
141141

142+
Template Profiles & Author Metadata
143+
===================================
144+
145+
Migrations inherit their header text, metadata comments, and default file format
146+
from ``migration_config["templates"]``. Each project can define multiple
147+
profiles and select one globally:
148+
149+
.. code-block:: python
150+
151+
migration_config={
152+
"default_format": "py", # CLI default when --format omitted
153+
"title": "Acme Migration", # Shared title for all templates
154+
"author": "env:SQLSPEC_AUTHOR", # Read from environment variable
155+
"templates": {
156+
"sql": {
157+
"header": "-- {title} - {message}",
158+
"metadata": ["-- Version: {version}", "-- Owner: {author}"],
159+
"body": "-- custom SQL body"
160+
},
161+
"py": {
162+
"docstring": """{title}\nDescription: {description}""",
163+
"imports": ["from typing import Iterable"],
164+
"body": """def up(context: object | None = None) -> str | Iterable[str]:
165+
return "SELECT 1"
166+
167+
def down(context: object | None = None) -> str | Iterable[str]:
168+
return "DROP TABLE example;"
169+
"""
170+
}
171+
}
172+
}
173+
174+
Template fragments accept the following variables:
175+
176+
- ``{title}`` – shared template title
177+
- ``{version}`` – generated revision identifier
178+
- ``{message}`` – CLI/command message
179+
- ``{description}`` – message fallback used in logs and docstrings
180+
- ``{created_at}`` – UTC timestamp in ISO 8601 format
181+
- ``{author}`` – resolved author string
182+
- ``{adapter}`` – config driver class (useful for docstrings)
183+
- ``{project_slug}`` / ``{slug}`` – sanitized project and message slugs
184+
185+
Missing placeholders raise ``TemplateValidationError`` so mistakes are caught
186+
immediately. SQL templates list metadata rows (``metadata``) and a ``body``
187+
block. Python templates expose ``docstring``, optional ``imports``, and ``body``.
188+
189+
Author attribution can be controlled via ``migration_config["author"]``:
190+
191+
- Literal strings (``"Data Platform"``) are stamped verbatim
192+
- ``"env:VAR_NAME"`` pulls from the environment and fails fast if unset
193+
- ``"callable:pkg.module:get_author"`` invokes a helper that can inspect the
194+
config or environment when determining the author string
195+
- ``"git"`` reads git user.name/email; ``"system"`` uses ``$USER``
196+
197+
CLI Enhancements
198+
----------------
199+
200+
``sqlspec create-migration`` (and ``litestar database create-migration``)
201+
accept ``--format`` / ``--file-type`` flags:
202+
203+
.. code-block:: bash
204+
205+
sqlspec --config myapp.config create-migration -m "Add seed data" --format py
206+
207+
When omitted, the CLI uses ``migration_config["default_format"]`` (``"sql"`` by default).
208+
Upgrade/downgrade commands now echo ``{version}: {description}``, so the rich
209+
description captured in templates is visible during deployments and matches the
210+
continue-on-error logs.
211+
212+
The default Python template ships with both ``up`` and ``down`` functions that
213+
accept an optional ``context`` argument. When migrations run via SQLSpec, that
214+
parameter receives the active ``MigrationContext`` so you can reach the config
215+
or connection objects directly inside your migration logic.
216+
142217
``create_migration(message, file_type="sql")``
143218
Create a new migration file.
144219

@@ -452,7 +527,8 @@ SQLSpec uses a tracking table to record applied migrations:
452527
renamed migrations (e.g., timestamp → sequential conversion).
453528

454529
``applied_by``
455-
Unix username of user who applied the migration.
530+
Author string recorded for the migration. Defaults to the git user/system
531+
account but can be overridden via ``migration_config["author"]``.
456532

457533
Schema Migration
458534
----------------

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ split-on-trailing-comma = false
503503
"docs/examples/**" = ["T201"]
504504
"sqlspec/builder/mixins/**/*.*" = ["SLF001"]
505505
"sqlspec/extensions/adk/converters.py" = ["S403"]
506+
"sqlspec/migrations/utils.py" = ["S404"]
506507
"tests/**/*.*" = [
507508
"A",
508509
"ARG",

sqlspec/cli.py

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING, Any, cast
77

88
import rich_click as click
9+
from click.core import ParameterSource
910

1011
if TYPE_CHECKING:
1112
from rich_click import Group
@@ -80,6 +81,16 @@ def sqlspec_group(ctx: "click.Context", config: str, validate_config: bool) -> N
8081
return sqlspec_group
8182

8283

84+
def _ensure_click_context() -> "click.Context":
85+
"""Return the active Click context, raising if missing (for type-checkers)."""
86+
87+
context = click.get_current_context()
88+
if context is None: # pragma: no cover - click guarantees context in commands
89+
msg = "SQLSpec CLI commands require an active Click context"
90+
raise RuntimeError(msg)
91+
return cast("click.Context", context)
92+
93+
8394
def add_migration_commands(database_group: "Group | None" = None) -> "Group":
8495
"""Add migration commands to the database group.
8596
@@ -270,17 +281,12 @@ def show_database_revision( # pyright: ignore[reportUnusedFunction]
270281
from sqlspec.migrations.commands import create_migration_commands
271282
from sqlspec.utils.sync_tools import run_
272283

273-
ctx = click.get_current_context()
284+
ctx = _ensure_click_context()
274285

275286
async def _show_current_revision() -> None:
276287
# Check if this is a multi-config operation
277288
configs_to_process = process_multiple_configs(
278-
cast("click.Context", ctx),
279-
bind_key,
280-
include,
281-
exclude,
282-
dry_run=False,
283-
operation_name="show current revision",
289+
ctx, bind_key, include, exclude, dry_run=False, operation_name="show current revision"
284290
)
285291

286292
if configs_to_process is not None:
@@ -300,7 +306,7 @@ async def _show_current_revision() -> None:
300306
console.print(f"[red]✗ Failed to get current revision for {config_name}: {e}[/]")
301307
else:
302308
console.rule("[yellow]Listing current revision[/]", align="left")
303-
sqlspec_config = get_config_by_bind_key(cast("click.Context", ctx), bind_key)
309+
sqlspec_config = get_config_by_bind_key(ctx, bind_key)
304310
migration_commands = create_migration_commands(config=sqlspec_config)
305311
await maybe_await(migration_commands.current(verbose=verbose))
306312

@@ -327,17 +333,12 @@ def downgrade_database( # pyright: ignore[reportUnusedFunction]
327333
from sqlspec.migrations.commands import create_migration_commands
328334
from sqlspec.utils.sync_tools import run_
329335

330-
ctx = click.get_current_context()
336+
ctx = _ensure_click_context()
331337

332338
async def _downgrade_database() -> None:
333339
# Check if this is a multi-config operation
334340
configs_to_process = process_multiple_configs(
335-
cast("click.Context", ctx),
336-
bind_key,
337-
include,
338-
exclude,
339-
dry_run=dry_run,
340-
operation_name=f"downgrade to {revision}",
341+
ctx, bind_key, include, exclude, dry_run=dry_run, operation_name=f"downgrade to {revision}"
341342
)
342343

343344
if configs_to_process is not None:
@@ -371,7 +372,7 @@ async def _downgrade_database() -> None:
371372
else Confirm.ask(f"Are you sure you want to downgrade the database to the `{revision}` revision?")
372373
)
373374
if input_confirmed:
374-
sqlspec_config = get_config_by_bind_key(cast("click.Context", ctx), bind_key)
375+
sqlspec_config = get_config_by_bind_key(ctx, bind_key)
375376
migration_commands = create_migration_commands(config=sqlspec_config)
376377
await maybe_await(migration_commands.downgrade(revision=revision, dry_run=dry_run))
377378

@@ -402,7 +403,7 @@ def upgrade_database( # pyright: ignore[reportUnusedFunction]
402403
from sqlspec.migrations.commands import create_migration_commands
403404
from sqlspec.utils.sync_tools import run_
404405

405-
ctx = click.get_current_context()
406+
ctx = _ensure_click_context()
406407

407408
async def _upgrade_database() -> None:
408409
# Report execution mode when specified
@@ -411,7 +412,7 @@ async def _upgrade_database() -> None:
411412

412413
# Check if this is a multi-config operation
413414
configs_to_process = process_multiple_configs(
414-
cast("click.Context", ctx), bind_key, include, exclude, dry_run, operation_name=f"upgrade to {revision}"
415+
ctx, bind_key, include, exclude, dry_run, operation_name=f"upgrade to {revision}"
415416
)
416417

417418
if configs_to_process is not None:
@@ -449,7 +450,7 @@ async def _upgrade_database() -> None:
449450
)
450451
)
451452
if input_confirmed:
452-
sqlspec_config = get_config_by_bind_key(cast("click.Context", ctx), bind_key)
453+
sqlspec_config = get_config_by_bind_key(ctx, bind_key)
453454
migration_commands = create_migration_commands(config=sqlspec_config)
454455
await maybe_await(
455456
migration_commands.upgrade(revision=revision, auto_sync=not no_auto_sync, dry_run=dry_run)
@@ -465,10 +466,10 @@ def stamp(bind_key: str | None, revision: str) -> None: # pyright: ignore[repor
465466
from sqlspec.migrations.commands import create_migration_commands
466467
from sqlspec.utils.sync_tools import run_
467468

468-
ctx = click.get_current_context()
469+
ctx = _ensure_click_context()
469470

470471
async def _stamp() -> None:
471-
sqlspec_config = get_config_by_bind_key(cast("click.Context", ctx), bind_key)
472+
sqlspec_config = get_config_by_bind_key(ctx, bind_key)
472473
migration_commands = create_migration_commands(config=sqlspec_config)
473474
await maybe_await(migration_commands.stamp(revision=revision))
474475

@@ -488,7 +489,7 @@ def init_sqlspec( # pyright: ignore[reportUnusedFunction]
488489
from sqlspec.migrations.commands import create_migration_commands
489490
from sqlspec.utils.sync_tools import run_
490491

491-
ctx = click.get_current_context()
492+
ctx = _ensure_click_context()
492493

493494
async def _init_sqlspec() -> None:
494495
console.rule("[yellow]Initializing database migrations.", align="left")
@@ -498,11 +499,7 @@ async def _init_sqlspec() -> None:
498499
else Confirm.ask("[bold]Are you sure you want initialize migrations for the project?[/]")
499500
)
500501
if input_confirmed:
501-
configs = (
502-
[get_config_by_bind_key(cast("click.Context", ctx), bind_key)]
503-
if bind_key is not None
504-
else cast("click.Context", ctx).obj["configs"]
505-
)
502+
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
506503

507504
for config in configs:
508505
migration_config = getattr(config, "migration_config", {})
@@ -519,17 +516,25 @@ async def _init_sqlspec() -> None:
519516
)
520517
@bind_key_option
521518
@click.option("-m", "--message", default=None, help="Revision message")
519+
@click.option(
520+
"--format",
521+
"--file-type",
522+
"file_format",
523+
type=click.Choice(["sql", "py"]),
524+
default=None,
525+
help="File format for the generated migration (defaults to template profile)",
526+
)
522527
@no_prompt_option
523528
def create_revision( # pyright: ignore[reportUnusedFunction]
524-
bind_key: str | None, message: str | None, no_prompt: bool
529+
bind_key: str | None, message: str | None, file_format: str | None, no_prompt: bool
525530
) -> None:
526531
"""Create a new database revision."""
527532
from rich.prompt import Prompt
528533

529534
from sqlspec.migrations.commands import create_migration_commands
530535
from sqlspec.utils.sync_tools import run_
531536

532-
ctx = click.get_current_context()
537+
ctx = _ensure_click_context()
533538

534539
async def _create_revision() -> None:
535540
console.rule("[yellow]Creating new migration revision[/]", align="left")
@@ -539,9 +544,11 @@ async def _create_revision() -> None:
539544
"new migration" if no_prompt else Prompt.ask("Please enter a message describing this revision")
540545
)
541546

542-
sqlspec_config = get_config_by_bind_key(cast("click.Context", ctx), bind_key)
547+
sqlspec_config = get_config_by_bind_key(ctx, bind_key)
548+
param_source = ctx.get_parameter_source("file_format")
549+
effective_format = None if param_source is ParameterSource.DEFAULT else file_format
543550
migration_commands = create_migration_commands(config=sqlspec_config)
544-
await maybe_await(migration_commands.revision(message=message_text))
551+
await maybe_await(migration_commands.revision(message=message_text, file_type=effective_format))
545552

546553
run_(_create_revision)()
547554

@@ -557,11 +564,11 @@ def fix_migrations( # pyright: ignore[reportUnusedFunction]
557564
from sqlspec.migrations.commands import create_migration_commands
558565
from sqlspec.utils.sync_tools import run_
559566

560-
ctx = click.get_current_context()
567+
ctx = _ensure_click_context()
561568

562569
async def _fix_migrations() -> None:
563570
console.rule("[yellow]Migration Fix Command[/]", align="left")
564-
sqlspec_config = get_config_by_bind_key(cast("click.Context", ctx), bind_key)
571+
sqlspec_config = get_config_by_bind_key(ctx, bind_key)
565572
migration_commands = create_migration_commands(config=sqlspec_config)
566573
await maybe_await(migration_commands.fix(dry_run=dry_run, update_database=not no_database, yes=yes))
567574

@@ -573,20 +580,20 @@ def show_config(bind_key: str | None = None) -> None: # pyright: ignore[reportU
573580
"""Show and display all configurations with migrations enabled."""
574581
from rich.table import Table
575582

576-
ctx = click.get_current_context()
583+
ctx = _ensure_click_context()
577584

578585
# If bind_key is provided, filter to only that config
579586
if bind_key is not None:
580-
get_config_by_bind_key(cast("click.Context", ctx), bind_key)
587+
get_config_by_bind_key(ctx, bind_key)
581588
# Convert single config to list format for compatibility
582-
all_configs = cast("click.Context", ctx).obj["configs"]
589+
all_configs = ctx.obj["configs"]
583590
migration_configs = []
584591
for cfg in all_configs:
585592
config_name = cfg.bind_key
586593
if config_name == bind_key and hasattr(cfg, "migration_config") and cfg.migration_config:
587594
migration_configs.append((config_name, cfg))
588595
else:
589-
migration_configs = get_configs_with_migrations(cast("click.Context", ctx))
596+
migration_configs = get_configs_with_migrations(ctx)
590597

591598
if not migration_configs:
592599
console.print("[yellow]No configurations with migrations detected.[/]")

0 commit comments

Comments
 (0)