66from typing import TYPE_CHECKING , Any , cast
77
88import rich_click as click
9+ from click .core import ParameterSource
910
1011if TYPE_CHECKING :
1112 from rich_click import Group
@@ -80,6 +81,16 @@ def sqlspec_group(ctx: "click.Context", config: str, validate_config: bool) -> N
8081 return sqlspec_group
8182
8283
84+ def _ensure_click_context () -> "click.Context" :
85+ """Return the active Click context, raising if missing (for type-checkers)."""
86+
87+ context = click .get_current_context ()
88+ if context is None : # pragma: no cover - click guarantees context in commands
89+ msg = "SQLSpec CLI commands require an active Click context"
90+ raise RuntimeError (msg )
91+ return cast ("click.Context" , context )
92+
93+
8394def add_migration_commands (database_group : "Group | None" = None ) -> "Group" :
8495 """Add migration commands to the database group.
8596
@@ -270,17 +281,12 @@ def show_database_revision( # pyright: ignore[reportUnusedFunction]
270281 from sqlspec .migrations .commands import create_migration_commands
271282 from sqlspec .utils .sync_tools import run_
272283
273- ctx = click . get_current_context ()
284+ ctx = _ensure_click_context ()
274285
275286 async def _show_current_revision () -> None :
276287 # Check if this is a multi-config operation
277288 configs_to_process = process_multiple_configs (
278- cast ("click.Context" , ctx ),
279- bind_key ,
280- include ,
281- exclude ,
282- dry_run = False ,
283- operation_name = "show current revision" ,
289+ ctx , bind_key , include , exclude , dry_run = False , operation_name = "show current revision"
284290 )
285291
286292 if configs_to_process is not None :
@@ -300,7 +306,7 @@ async def _show_current_revision() -> None:
300306 console .print (f"[red]✗ Failed to get current revision for { config_name } : { e } [/]" )
301307 else :
302308 console .rule ("[yellow]Listing current revision[/]" , align = "left" )
303- sqlspec_config = get_config_by_bind_key (cast ( "click.Context" , ctx ) , bind_key )
309+ sqlspec_config = get_config_by_bind_key (ctx , bind_key )
304310 migration_commands = create_migration_commands (config = sqlspec_config )
305311 await maybe_await (migration_commands .current (verbose = verbose ))
306312
@@ -327,17 +333,12 @@ def downgrade_database( # pyright: ignore[reportUnusedFunction]
327333 from sqlspec .migrations .commands import create_migration_commands
328334 from sqlspec .utils .sync_tools import run_
329335
330- ctx = click . get_current_context ()
336+ ctx = _ensure_click_context ()
331337
332338 async def _downgrade_database () -> None :
333339 # Check if this is a multi-config operation
334340 configs_to_process = process_multiple_configs (
335- cast ("click.Context" , ctx ),
336- bind_key ,
337- include ,
338- exclude ,
339- dry_run = dry_run ,
340- operation_name = f"downgrade to { revision } " ,
341+ ctx , bind_key , include , exclude , dry_run = dry_run , operation_name = f"downgrade to { revision } "
341342 )
342343
343344 if configs_to_process is not None :
@@ -371,7 +372,7 @@ async def _downgrade_database() -> None:
371372 else Confirm .ask (f"Are you sure you want to downgrade the database to the `{ revision } ` revision?" )
372373 )
373374 if input_confirmed :
374- sqlspec_config = get_config_by_bind_key (cast ( "click.Context" , ctx ) , bind_key )
375+ sqlspec_config = get_config_by_bind_key (ctx , bind_key )
375376 migration_commands = create_migration_commands (config = sqlspec_config )
376377 await maybe_await (migration_commands .downgrade (revision = revision , dry_run = dry_run ))
377378
@@ -402,7 +403,7 @@ def upgrade_database( # pyright: ignore[reportUnusedFunction]
402403 from sqlspec .migrations .commands import create_migration_commands
403404 from sqlspec .utils .sync_tools import run_
404405
405- ctx = click . get_current_context ()
406+ ctx = _ensure_click_context ()
406407
407408 async def _upgrade_database () -> None :
408409 # Report execution mode when specified
@@ -411,7 +412,7 @@ async def _upgrade_database() -> None:
411412
412413 # Check if this is a multi-config operation
413414 configs_to_process = process_multiple_configs (
414- cast ( "click.Context" , ctx ) , bind_key , include , exclude , dry_run , operation_name = f"upgrade to { revision } "
415+ ctx , bind_key , include , exclude , dry_run , operation_name = f"upgrade to { revision } "
415416 )
416417
417418 if configs_to_process is not None :
@@ -449,7 +450,7 @@ async def _upgrade_database() -> None:
449450 )
450451 )
451452 if input_confirmed :
452- sqlspec_config = get_config_by_bind_key (cast ( "click.Context" , ctx ) , bind_key )
453+ sqlspec_config = get_config_by_bind_key (ctx , bind_key )
453454 migration_commands = create_migration_commands (config = sqlspec_config )
454455 await maybe_await (
455456 migration_commands .upgrade (revision = revision , auto_sync = not no_auto_sync , dry_run = dry_run )
@@ -465,10 +466,10 @@ def stamp(bind_key: str | None, revision: str) -> None: # pyright: ignore[repor
465466 from sqlspec .migrations .commands import create_migration_commands
466467 from sqlspec .utils .sync_tools import run_
467468
468- ctx = click . get_current_context ()
469+ ctx = _ensure_click_context ()
469470
470471 async def _stamp () -> None :
471- sqlspec_config = get_config_by_bind_key (cast ( "click.Context" , ctx ) , bind_key )
472+ sqlspec_config = get_config_by_bind_key (ctx , bind_key )
472473 migration_commands = create_migration_commands (config = sqlspec_config )
473474 await maybe_await (migration_commands .stamp (revision = revision ))
474475
@@ -488,7 +489,7 @@ def init_sqlspec( # pyright: ignore[reportUnusedFunction]
488489 from sqlspec .migrations .commands import create_migration_commands
489490 from sqlspec .utils .sync_tools import run_
490491
491- ctx = click . get_current_context ()
492+ ctx = _ensure_click_context ()
492493
493494 async def _init_sqlspec () -> None :
494495 console .rule ("[yellow]Initializing database migrations." , align = "left" )
@@ -498,11 +499,7 @@ async def _init_sqlspec() -> None:
498499 else Confirm .ask ("[bold]Are you sure you want initialize migrations for the project?[/]" )
499500 )
500501 if input_confirmed :
501- configs = (
502- [get_config_by_bind_key (cast ("click.Context" , ctx ), bind_key )]
503- if bind_key is not None
504- else cast ("click.Context" , ctx ).obj ["configs" ]
505- )
502+ configs = [get_config_by_bind_key (ctx , bind_key )] if bind_key is not None else ctx .obj ["configs" ]
506503
507504 for config in configs :
508505 migration_config = getattr (config , "migration_config" , {})
@@ -519,17 +516,25 @@ async def _init_sqlspec() -> None:
519516 )
520517 @bind_key_option
521518 @click .option ("-m" , "--message" , default = None , help = "Revision message" )
519+ @click .option (
520+ "--format" ,
521+ "--file-type" ,
522+ "file_format" ,
523+ type = click .Choice (["sql" , "py" ]),
524+ default = None ,
525+ help = "File format for the generated migration (defaults to template profile)" ,
526+ )
522527 @no_prompt_option
523528 def create_revision ( # pyright: ignore[reportUnusedFunction]
524- bind_key : str | None , message : str | None , no_prompt : bool
529+ bind_key : str | None , message : str | None , file_format : str | None , no_prompt : bool
525530 ) -> None :
526531 """Create a new database revision."""
527532 from rich .prompt import Prompt
528533
529534 from sqlspec .migrations .commands import create_migration_commands
530535 from sqlspec .utils .sync_tools import run_
531536
532- ctx = click . get_current_context ()
537+ ctx = _ensure_click_context ()
533538
534539 async def _create_revision () -> None :
535540 console .rule ("[yellow]Creating new migration revision[/]" , align = "left" )
@@ -539,9 +544,11 @@ async def _create_revision() -> None:
539544 "new migration" if no_prompt else Prompt .ask ("Please enter a message describing this revision" )
540545 )
541546
542- sqlspec_config = get_config_by_bind_key (cast ("click.Context" , ctx ), bind_key )
547+ sqlspec_config = get_config_by_bind_key (ctx , bind_key )
548+ param_source = ctx .get_parameter_source ("file_format" )
549+ effective_format = None if param_source is ParameterSource .DEFAULT else file_format
543550 migration_commands = create_migration_commands (config = sqlspec_config )
544- await maybe_await (migration_commands .revision (message = message_text ))
551+ await maybe_await (migration_commands .revision (message = message_text , file_type = effective_format ))
545552
546553 run_ (_create_revision )()
547554
@@ -557,11 +564,11 @@ def fix_migrations( # pyright: ignore[reportUnusedFunction]
557564 from sqlspec .migrations .commands import create_migration_commands
558565 from sqlspec .utils .sync_tools import run_
559566
560- ctx = click . get_current_context ()
567+ ctx = _ensure_click_context ()
561568
562569 async def _fix_migrations () -> None :
563570 console .rule ("[yellow]Migration Fix Command[/]" , align = "left" )
564- sqlspec_config = get_config_by_bind_key (cast ( "click.Context" , ctx ) , bind_key )
571+ sqlspec_config = get_config_by_bind_key (ctx , bind_key )
565572 migration_commands = create_migration_commands (config = sqlspec_config )
566573 await maybe_await (migration_commands .fix (dry_run = dry_run , update_database = not no_database , yes = yes ))
567574
@@ -573,20 +580,20 @@ def show_config(bind_key: str | None = None) -> None: # pyright: ignore[reportU
573580 """Show and display all configurations with migrations enabled."""
574581 from rich .table import Table
575582
576- ctx = click . get_current_context ()
583+ ctx = _ensure_click_context ()
577584
578585 # If bind_key is provided, filter to only that config
579586 if bind_key is not None :
580- get_config_by_bind_key (cast ( "click.Context" , ctx ) , bind_key )
587+ get_config_by_bind_key (ctx , bind_key )
581588 # Convert single config to list format for compatibility
582- all_configs = cast ( "click.Context" , ctx ) .obj ["configs" ]
589+ all_configs = ctx .obj ["configs" ]
583590 migration_configs = []
584591 for cfg in all_configs :
585592 config_name = cfg .bind_key
586593 if config_name == bind_key and hasattr (cfg , "migration_config" ) and cfg .migration_config :
587594 migration_configs .append ((config_name , cfg ))
588595 else :
589- migration_configs = get_configs_with_migrations (cast ( "click.Context" , ctx ) )
596+ migration_configs = get_configs_with_migrations (ctx )
590597
591598 if not migration_configs :
592599 console .print ("[yellow]No configurations with migrations detected.[/]" )
0 commit comments