From 13110e88b7b4209609bd8265541439491c128add Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Thu, 27 Nov 2025 15:24:05 -0500 Subject: [PATCH 1/3] Call parameterized views from sql #3489, parsing --- crates/expr/src/check.rs | 12 ++--- crates/expr/src/errors.rs | 6 +++ crates/sql-parser/src/ast/mod.rs | 29 +++++++++-- crates/sql-parser/src/ast/sql.rs | 1 + crates/sql-parser/src/ast/sub.rs | 1 + crates/sql-parser/src/parser/errors.rs | 7 +++ crates/sql-parser/src/parser/mod.rs | 66 ++++++++++++++++++++------ crates/sql-parser/src/parser/sql.rs | 22 +++++++++ crates/sql-parser/src/parser/sub.rs | 22 +++++++++ 9 files changed, 141 insertions(+), 25 deletions(-) diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs index 9c4717f9a9f..6f7cb8949e0 100644 --- a/crates/expr/src/check.rs +++ b/crates/expr/src/check.rs @@ -15,7 +15,7 @@ use spacetimedb_sql_parser::{ }; use super::{ - errors::{DuplicateName, TypingError, Unresolved, Unsupported}, + errors::{DuplicateName, FunctionCall, TypingError, Unresolved, Unsupported}, expr::RelExpr, type_expr, type_proj, type_select, }; @@ -78,12 +78,8 @@ pub trait TypeChecker { delta: None, }); - for SqlJoin { - var: SqlIdent(name), - alias: SqlIdent(alias), - on, - } in joins - { + for SqlJoin { from, on } in joins { + let (SqlIdent(name), SqlIdent(alias)) = from.into_name_alias(); // Check for duplicate aliases if vars.contains_key(&alias) { return Err(DuplicateName(alias.into_string()).into()); @@ -113,6 +109,8 @@ pub trait TypeChecker { Ok(join) } + // TODO: support function calls in FROM clause + SqlFrom::FuncCall(_, _) => Err(FunctionCall.into()), } } diff --git a/crates/expr/src/errors.rs b/crates/expr/src/errors.rs index e569c0d134a..9b4894ab1ac 100644 --- a/crates/expr/src/errors.rs +++ b/crates/expr/src/errors.rs @@ -128,6 +128,10 @@ pub struct DmlOnView { pub view_name: Box, } +#[derive(Debug, Error)] +#[error("Function calls are not supported")] +pub struct FunctionCall; + #[derive(Error, Debug)] pub enum TypingError { #[error(transparent)] @@ -157,4 +161,6 @@ pub enum TypingError { DuplicateName(#[from] DuplicateName), #[error(transparent)] FilterReturnType(#[from] FilterReturnType), + #[error(transparent)] + FunctionCall(#[from] FunctionCall), } diff --git a/crates/sql-parser/src/ast/mod.rs b/crates/sql-parser/src/ast/mod.rs index 776d4fc5006..671c6c342b8 100644 --- a/crates/sql-parser/src/ast/mod.rs +++ b/crates/sql-parser/src/ast/mod.rs @@ -6,11 +6,12 @@ use sqlparser::ast::Ident; pub mod sql; pub mod sub; -/// The FROM clause is either a relvar or a JOIN +/// The FROM clause is either a relvar, a JOIN, or a function call #[derive(Debug)] pub enum SqlFrom { Expr(SqlIdent, SqlIdent), Join(SqlIdent, SqlIdent, Vec), + FuncCall(SqlFuncCall, SqlIdent), } impl SqlFrom { @@ -22,11 +23,26 @@ impl SqlFrom { } } +/// A source in a FROM clause, restricted to a single relvar or function call +#[derive(Debug)] +pub enum SqlFromSource { + Expr(SqlIdent, SqlIdent), + FuncCall(SqlFuncCall, SqlIdent), +} + +impl SqlFromSource { + pub fn into_name_alias(self) -> (SqlIdent, SqlIdent) { + match self { + Self::Expr(name, alias) => (name, alias), + Self::FuncCall(func, alias) => (func.name, alias), + } + } +} + /// An inner join in a FROM clause #[derive(Debug)] pub struct SqlJoin { - pub var: SqlIdent, - pub alias: SqlIdent, + pub from: SqlFromSource, pub on: Option, } @@ -247,3 +263,10 @@ impl Display for LogOp { } } } + +/// A SQL function call +#[derive(Debug)] +pub struct SqlFuncCall { + pub name: SqlIdent, + pub args: Vec, +} diff --git a/crates/sql-parser/src/ast/sql.rs b/crates/sql-parser/src/ast/sql.rs index 567b5ec5328..4e6b5176da3 100644 --- a/crates/sql-parser/src/ast/sql.rs +++ b/crates/sql-parser/src/ast/sql.rs @@ -78,6 +78,7 @@ impl SqlSelect { ..self }, SqlFrom::Join(..) => self, + SqlFrom::FuncCall(..) => self, } } diff --git a/crates/sql-parser/src/ast/sub.rs b/crates/sql-parser/src/ast/sub.rs index bd6fde0d98c..6ba9db11982 100644 --- a/crates/sql-parser/src/ast/sub.rs +++ b/crates/sql-parser/src/ast/sub.rs @@ -21,6 +21,7 @@ impl SqlSelect { from: self.from, }, SqlFrom::Join(..) => self, + SqlFrom::FuncCall(..) => self, } } diff --git a/crates/sql-parser/src/parser/errors.rs b/crates/sql-parser/src/parser/errors.rs index 953a031b8b8..afa053e36b3 100644 --- a/crates/sql-parser/src/parser/errors.rs +++ b/crates/sql-parser/src/parser/errors.rs @@ -1,5 +1,6 @@ use std::fmt::Display; +use sqlparser::ast::FunctionArg; use sqlparser::{ ast::{ BinaryOperator, Expr, Function, ObjectName, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins, @@ -77,6 +78,12 @@ pub enum SqlUnsupported { Empty, #[error("Names must be qualified when using joins")] UnqualifiedNames, + #[error("Unsupported function argument: {0}")] + FuncArg(FunctionArg), + #[error("Unsupported call to table-valued function with empty params. Use `select * from table_function` syntax instead: {0}")] + TableFunctionNoParams(String), + #[error("Unsupported JOIN with table-valued function: {0}")] + JoinTableFunction(String), } impl SqlUnsupported { diff --git a/crates/sql-parser/src/parser/mod.rs b/crates/sql-parser/src/parser/mod.rs index 9e6e5642bda..46c463c2b76 100644 --- a/crates/sql-parser/src/parser/mod.rs +++ b/crates/sql-parser/src/parser/mod.rs @@ -6,7 +6,8 @@ use sqlparser::ast::{ }; use crate::ast::{ - BinOp, LogOp, Parameter, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlIdent, SqlJoin, SqlLiteral, + BinOp, LogOp, Parameter, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlFromSource, SqlFuncCall, SqlIdent, + SqlJoin, SqlLiteral, }; pub mod errors; @@ -34,11 +35,15 @@ trait RelParser { return Err(SqlUnsupported::ImplicitJoins.into()); } let TableWithJoins { relation, joins } = tables.swap_remove(0); - let (name, alias) = Self::parse_relvar(relation)?; - if joins.is_empty() { - return Ok(SqlFrom::Expr(name, alias)); + match Self::parse_relvar(relation)? { + SqlFromSource::Expr(name, alias) => { + if joins.is_empty() { + return Ok(SqlFrom::Expr(name, alias)); + } + Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?)) + } + SqlFromSource::FuncCall(func_call, alias) => Ok(SqlFrom::FuncCall(func_call, alias)), } - Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?)) } /// Parse a sequence of JOIN clauses @@ -48,10 +53,11 @@ trait RelParser { /// Parse a single JOIN clause fn parse_join(join: Join) -> SqlParseResult { - let (var, alias) = Self::parse_relvar(join.relation)?; + let from = Self::parse_relvar(join.relation)?; + match join.join_operator { - JoinOperator::CrossJoin => Ok(SqlJoin { var, alias, on: None }), - JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { var, alias, on: None }), + JoinOperator::CrossJoin => Ok(SqlJoin { from, on: None }), + JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { from, on: None }), JoinOperator::Inner(JoinConstraint::On(Expr::BinaryOp { left, op: BinaryOperator::Eq, @@ -60,8 +66,7 @@ trait RelParser { && matches!(*right, Expr::Identifier(..) | Expr::CompoundIdentifier(..)) => { Ok(SqlJoin { - var, - alias, + from, on: Some(parse_expr( Expr::BinaryOp { left, @@ -76,32 +81,63 @@ trait RelParser { } } + /// Parse a function call + fn parse_func_call(name: SqlIdent, args: Vec) -> SqlParseResult { + if args.is_empty() { + return Err(SqlUnsupported::TableFunctionNoParams(name.0.into()).into()); + } + let args = args + .into_iter() + .map(|arg| match arg.clone() { + FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => match parse_expr(expr, 0) { + Ok(SqlExpr::Lit(lit)) => Ok(lit), + _ => Err(SqlUnsupported::FuncArg(arg).into()), + }, + _ => Err(SqlUnsupported::FuncArg(arg.clone()).into()), + }) + .collect::>()?; + Ok(SqlFuncCall { name, args }) + } + /// Parse a table reference in a FROM clause - fn parse_relvar(expr: TableFactor) -> SqlParseResult<(SqlIdent, SqlIdent)> { + fn parse_relvar(expr: TableFactor) -> SqlParseResult { match expr { // Relvar no alias TableFactor::Table { name, alias: None, - args: None, + args, with_hints, version: None, partitions, } if with_hints.is_empty() && partitions.is_empty() => { let name = parse_ident(name)?; let alias = name.clone(); - Ok((name, alias)) + + if let Some(args) = args { + Ok(SqlFromSource::FuncCall(Self::parse_func_call(name, args)?, alias)) + } else { + Ok(SqlFromSource::Expr(name, alias)) + } } // Relvar with alias TableFactor::Table { name, alias: Some(TableAlias { name: alias, columns }), - args: None, + args, with_hints, version: None, partitions, } if with_hints.is_empty() && partitions.is_empty() && columns.is_empty() => { - Ok((parse_ident(name)?, alias.into())) + let args = args.filter(|v| !v.is_empty()); + if let Some(args) = args { + Ok(SqlFromSource::FuncCall( + Self::parse_func_call(parse_ident(name)?, args)?, + alias.into(), + )) + } else { + Ok(SqlFromSource::Expr(parse_ident(name)?, alias.into())) + } } _ => Err(SqlUnsupported::From(expr).into()), } diff --git a/crates/sql-parser/src/parser/sql.rs b/crates/sql-parser/src/parser/sql.rs index a1eb5078726..9dee72042da 100644 --- a/crates/sql-parser/src/parser/sql.rs +++ b/crates/sql-parser/src/parser/sql.rs @@ -68,9 +68,18 @@ //! | SUM '(' columnExpr ')' AS ident //! ; //! +//! paramExpr +//! = literal +//! ; +//! +//! functionCall +//! = ident '(' [ paramExpr { ',' paramExpr } ] ')' +//! ; +//! //! relation //! = table //! | '(' query ')' +//! | functionCall //! | relation [ [AS] ident ] { [INNER] JOIN relation [ [AS] ident ] ON predicate } //! ; //! @@ -442,6 +451,11 @@ mod tests { "select a from t where x = :sender", "select count(*) as n from t", "select count(*) as n from t join s on t.id = s.id where s.x = 1", + "select * from sample as s", + "select * from sample(1, 'abc', true, 0xFF, 0.1)", + "select * from sample(1, 'abc', true, 0xFF, 0.1) as s", + "select * from t join sample(1) on t.id = sample.id", + "select * from t join sample(1) as s on t.id = s.id", "insert into t values (1, 2)", "delete from t", "delete from t where a = 1", @@ -463,6 +477,14 @@ mod tests { "select a from where b = 1", // Empty WHERE "select a from t where", + // Function call params are not literals + "select * from sample(a, b)", + // Function call without params + "select * from sample()", + // Nested function call + "select * from sample(sample(1))", + // Function call in JOIN ON + "select * from t join sample(1) on t.id = sample(1).id", // Empty GROUP BY "select a, count(*) from t group by", // Aggregate without alias diff --git a/crates/sql-parser/src/parser/sub.rs b/crates/sql-parser/src/parser/sub.rs index 6a8ef34a1e8..a60cbb81132 100644 --- a/crates/sql-parser/src/parser/sub.rs +++ b/crates/sql-parser/src/parser/sub.rs @@ -10,9 +10,18 @@ //! | ident '.' STAR //! ; //! +//! paramExpr +//! = literal +//! ; +//! +//! functionCall +//! = ident '(' [ paramExpr { ',' paramExpr } ] ')' +//! ; +//! //! relation //! = table //! | '(' query ')' +//! | functionCall //! | relation [ [AS] ident ] { [INNER] JOIN relation [ [AS] ident ] ON predicate } //! ; //! @@ -162,6 +171,14 @@ mod tests { "", "select distinct a from t", "select * from (select * from t) join (select * from s) on a = b", + // Function call params are not literals + "select * from sample(a, b)", + // Function call without params + "select * from sample()", + // Nested function call + "select * from sample(sample(1))", + // Function call in JOIN ON + "select * from t join sample(1) on t.id = sample(1).id", ] { assert!(parse_subscription(sql).is_err()); } @@ -178,6 +195,11 @@ mod tests { "select t.* from t join s on t.c = s.d", "select a.* from t as a join s as b on a.c = b.d", "select * from t where x = :sender", + "select * from sample as s", + "select * from sample(1, 'abc', true, 0xFF, 0.1)", + "select * from sample(1, 'abc', true, 0xFF, 0.1) as s", + "select * from t join sample(1) on t.id = sample.id", + "select * from t join sample(1) as s on t.id = s.id", ] { assert!(parse_subscription(sql).is_ok()); } From d2dde1fb5b258389bc4cae5bf429eac2ca1b9a49 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Fri, 28 Nov 2025 12:51:25 -0500 Subject: [PATCH 2/3] Call parameterized views from sql #3489, plan and execution --- crates/core/src/db/relational_db.rs | 7 +- crates/core/src/sql/ast.rs | 8 +- crates/core/src/sql/execute.rs | 84 +++++- .../subscription/module_subscription_actor.rs | 2 +- crates/core/src/subscription/subscription.rs | 2 +- .../src/locking_tx_datastore/mut_tx.rs | 30 ++ .../src/locking_tx_datastore/state_view.rs | 20 +- crates/datastore/src/system_tables.rs | 7 + crates/expr/src/check.rs | 282 ++++++++++++++---- crates/expr/src/errors.rs | 17 +- crates/expr/src/expr.rs | 71 ++++- crates/expr/src/lib.rs | 17 +- crates/expr/src/rls.rs | 39 +-- crates/expr/src/statement.rs | 183 +++++++----- crates/physical-plan/src/compile.rs | 2 +- crates/physical-plan/src/plan.rs | 101 ++++++- crates/query/src/lib.rs | 18 +- crates/schema/src/def.rs | 10 + crates/schema/src/schema.rs | 65 ++-- crates/sql-parser/src/ast/mod.rs | 8 +- modules/module-test/src/lib.rs | 5 + 21 files changed, 746 insertions(+), 232 deletions(-) diff --git a/crates/core/src/db/relational_db.rs b/crates/core/src/db/relational_db.rs index 93bf29b8091..dbe657b2918 100644 --- a/crates/core/src/db/relational_db.rs +++ b/crates/core/src/db/relational_db.rs @@ -1122,6 +1122,10 @@ impl RelationalDB { Ok(tx.create_view(module_def, view_def)?) } + pub fn create_or_get_params(&self, tx: &mut MutTx, params: &ProductValue) -> Result { + Ok(tx.create_or_get_params(params)?) + } + pub fn drop_view(&self, tx: &mut MutTx, view_id: ViewId) -> Result<(), DBError> { Ok(tx.drop_view(view_id)?) } @@ -2217,6 +2221,7 @@ pub mod tests_utils { db: &RelationalDB, name: &str, schema: &[(&str, AlgebraicType)], + params: ProductType, is_anonymous: bool, ) -> Result<(ViewId, TableId), DBError> { let mut builder = RawModuleDefV9Builder::new(); @@ -2234,7 +2239,7 @@ pub mod tests_utils { 0, true, is_anonymous, - ProductType::unit(), + params, AlgebraicType::array(AlgebraicType::Ref(type_ref)), ); diff --git a/crates/core/src/sql/ast.rs b/crates/core/src/sql/ast.rs index 26c52d9b126..a11775331ce 100644 --- a/crates/core/src/sql/ast.rs +++ b/crates/core/src/sql/ast.rs @@ -4,11 +4,11 @@ use anyhow::Context; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; use spacetimedb_datastore::system_tables::{StRowLevelSecurityFields, ST_ROW_LEVEL_SECURITY_ID}; -use spacetimedb_expr::check::SchemaView; +use spacetimedb_expr::check::{SchemaView, TypingResult}; use spacetimedb_expr::statement::compile_sql_stmt; use spacetimedb_lib::identity::AuthCtx; -use spacetimedb_primitives::{ColId, TableId}; -use spacetimedb_sats::{AlgebraicType, AlgebraicValue}; +use spacetimedb_primitives::{ArgId, ColId, TableId}; +use spacetimedb_sats::{AlgebraicType, AlgebraicValue, ProductValue}; use spacetimedb_schema::def::error::RelationError; use spacetimedb_schema::relation::{ColExpr, FieldName}; use spacetimedb_schema::schema::{ColumnSchema, TableOrViewSchema, TableSchema}; @@ -477,7 +477,7 @@ fn compile_where(table: &From, filter: Option) -> Result { - tx: &'a T, + pub(crate) tx: &'a T, auth: &'a AuthCtx, } diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index 7b0d50ff421..2e2e2337509 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -1,7 +1,8 @@ +use std::sync::Arc; use std::time::Duration; use super::ast::SchemaViewer; -use crate::db::relational_db::{RelationalDB, Tx}; +use crate::db::relational_db::{MutTx, RelationalDB, Tx}; use crate::energy::EnergyQuanta; use crate::error::DBError; use crate::estimation::estimate_rows_scanned; @@ -19,13 +20,18 @@ use anyhow::anyhow; use spacetimedb_datastore::execution_context::Workload; use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; use spacetimedb_datastore::traits::IsolationLevel; +use spacetimedb_expr::check::SchemaView; +use spacetimedb_expr::errors::TypingError; +use spacetimedb_expr::expr::CallParams; use spacetimedb_expr::statement::Statement; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::Timestamp; use spacetimedb_lib::{AlgebraicType, ProductType, ProductValue}; +use spacetimedb_primitives::{ArgId, TableId}; use spacetimedb_query::{compile_sql_stmt, execute_dml_stmt, execute_select_stmt}; use spacetimedb_schema::relation::FieldName; +use spacetimedb_schema::schema::TableOrViewSchema; use spacetimedb_vm::eval::run_ast; use spacetimedb_vm::expr::{CodeResult, CrudExpr, Expr}; use spacetimedb_vm::relation::MemTable; @@ -185,6 +191,19 @@ pub struct SqlResult { pub metrics: ExecutionMetrics, } +struct DbParams<'a> { + db: &'a RelationalDB, + tx: &'a mut MutTx, +} + +impl CallParams for DbParams<'_> { + fn create_or_get_param(&mut self, param: &ProductValue) -> Result { + self.db + .create_or_get_params(self.tx, ¶m) + .map_err(|err| TypingError::Other(err.into())) + } +} + /// Run the `SQL` string using the `auth` credentials pub async fn run( db: &RelationalDB, @@ -196,9 +215,20 @@ pub async fn run( ) -> Result { // We parse the sql statement in a mutable transaction. // If it turns out to be a query, we downgrade the tx. - let (tx, stmt) = db.with_auto_rollback(db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), |tx| { - compile_sql_stmt(sql_text, &SchemaViewer::new(tx, &auth), &auth) - })?; + let (tx, stmt) = + db.with_auto_rollback( + db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), + |tx| match compile_sql_stmt(sql_text, &SchemaViewer::new(tx, &auth), &auth) { + Ok(Statement::Select(mut stmt)) => { + stmt.for_each_fun_call(&mut |param| { + db.create_or_get_params(tx, ¶m) + .map_err(|err| TypingError::Other(err.into())) + })?; + Ok(Statement::Select(stmt)) + } + result => result, + }, + )?; let mut metrics = ExecutionMetrics::default(); @@ -345,7 +375,8 @@ pub(crate) mod tests { use itertools::Itertools; use pretty_assertions::assert_eq; use spacetimedb_datastore::system_tables::{ - StRowLevelSecurityRow, StTableFields, ST_ROW_LEVEL_SECURITY_ID, ST_TABLE_ID, ST_TABLE_NAME, + StRowLevelSecurityRow, StTableFields, ST_RESERVED_SEQUENCE_RANGE, ST_ROW_LEVEL_SECURITY_ID, ST_TABLE_ID, + ST_TABLE_NAME, }; use spacetimedb_lib::bsatn::ToBsatn; use spacetimedb_lib::db::auth::{StAccess, StTableType}; @@ -958,7 +989,7 @@ pub(crate) mod tests { let db = TestDB::in_memory()?; let schema = [("a", AlgebraicType::U8), ("b", AlgebraicType::U8)]; - let (_, table_id) = tests_utils::create_view_for_test(&db, "my_view", &schema, false)?; + let (_, table_id) = tests_utils::create_view_for_test(&db, "my_view", &schema, ProductType::unit(), false)?; with_auto_commit(&db, |tx| -> Result<_, DBError> { tests_utils::insert_into_view(&db, tx, table_id, Some(identity_from_u8(1)), product![0u8, 1u8])?; @@ -979,7 +1010,7 @@ pub(crate) mod tests { let db = TestDB::in_memory()?; let schema = [("a", AlgebraicType::U8), ("b", AlgebraicType::U8)]; - let (_, table_id) = tests_utils::create_view_for_test(&db, "my_view", &schema, true)?; + let (_, table_id) = tests_utils::create_view_for_test(&db, "my_view", &schema, ProductType::unit(), true)?; with_auto_commit(&db, |tx| -> Result<_, DBError> { tests_utils::insert_into_view(&db, tx, table_id, None, product![0u8, 1u8])?; @@ -1000,7 +1031,7 @@ pub(crate) mod tests { let db = TestDB::in_memory()?; let schema = [("a", AlgebraicType::U8), ("b", AlgebraicType::U8)]; - let (_, v_id) = tests_utils::create_view_for_test(&db, "v", &schema, false)?; + let (_, v_id) = tests_utils::create_view_for_test(&db, "v", &schema, ProductType::unit(), false)?; let schema = [("c", AlgebraicType::U8), ("d", AlgebraicType::U8)]; let t_id = db.create_table_for_test("t", &schema, &[0.into()])?; @@ -1060,10 +1091,10 @@ pub(crate) mod tests { let db = TestDB::in_memory()?; let schema = [("a", AlgebraicType::U8), ("b", AlgebraicType::U8)]; - let (_, u_id) = tests_utils::create_view_for_test(&db, "u", &schema, false)?; + let (_, u_id) = tests_utils::create_view_for_test(&db, "u", &schema, ProductType::unit(), false)?; let schema = [("c", AlgebraicType::U8), ("d", AlgebraicType::U8)]; - let (_, v_id) = tests_utils::create_view_for_test(&db, "v", &schema, false)?; + let (_, v_id) = tests_utils::create_view_for_test(&db, "v", &schema, ProductType::unit(), false)?; with_auto_commit(&db, |tx| -> Result<_, DBError> { tests_utils::insert_into_view(&db, tx, u_id, Some(identity_from_u8(1)), product![0u8, 1u8])?; @@ -1574,4 +1605,37 @@ pub(crate) mod tests { Ok(()) } + + // Verify calling views with params + #[test] + fn test_view_params() -> ResultTest<()> { + let db = TestDB::durable()?; + let schema = [("a", AlgebraicType::U8), ("b", AlgebraicType::I64)]; + let (_view_id, table_id) = tests_utils::create_view_for_test( + &db, + "my_view", + &schema, + ProductType::from([("x", AlgebraicType::U8)]), + true, + )?; + let arg_id = ST_RESERVED_SEQUENCE_RANGE as u64; + + with_auto_commit(&db, |tx| -> Result<_, DBError> { + tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id + 1, 0u8, 1i64])?; + tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id, 1u8, 2i64])?; + Ok(()) + })?; + + assert_eq!( + run_for_testing(&db, "select * from my_view(1)")?, + vec![product![1u8, 2i64]] + ); + + assert_eq!( + run_for_testing(&db, "select * from my_view(2)")?, + vec![product![0u8, 1i64]] + ); + + Ok(()) + } } diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 703873aebcb..0396a12f10f 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -327,7 +327,7 @@ impl ModuleSubscriptions { let view_info = plans .first() .and_then(|plan| plan.return_table()) - .and_then(|schema| schema.view_info); + .and_then(|schema| schema.view_info.clone()); let num_cols = plans .first() diff --git a/crates/core/src/subscription/subscription.rs b/crates/core/src/subscription/subscription.rs index d96aee8fc93..1fb13e8ee56 100644 --- a/crates/core/src/subscription/subscription.rs +++ b/crates/core/src/subscription/subscription.rs @@ -609,7 +609,7 @@ impl AuthAccess for ExecutionSet { } } -/// Querieshttps://github.com/clockworklabs/SpacetimeDBPrivate/pull/2207 all the [`StTableType::User`] tables *right now* +/// Queries https://github.com/clockworklabs/SpacetimeDBPrivate/pull/2207 all the [`StTableType::User`] tables *right now* /// and turns them into [`QueryExpr`], /// the moral equivalent of `SELECT * FROM table`. pub(crate) fn get_all( diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index 0d83d6824be..7f7292c556f 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -8,6 +8,8 @@ use super::{ tx_state::{IndexIdMap, PendingSchemaChange, TxState, TxTableForInsertion}, SharedMutexGuard, SharedWriteGuard, }; +use crate::error::DatastoreError; +use crate::system_tables::{StViewArgFields, StViewArgRow, ST_VIEW_ARG_ID}; use crate::{ error::ViewError, locking_tx_datastore::state_view::EqOnColumn, @@ -48,6 +50,7 @@ use spacetimedb_lib::{ use spacetimedb_primitives::{ col_list, ArgId, ColId, ColList, ColSet, ConstraintId, IndexId, ScheduleId, SequenceId, TableId, ViewFnPtr, ViewId, }; +use spacetimedb_sats::bsatn::ToBsatn; use spacetimedb_sats::{ bsatn::{self, to_writer, DecodeError, Deserializer}, de::{DeserializeSeed, WithBound}, @@ -589,6 +592,33 @@ impl MutTxId { Ok(()) } + pub fn get_params(&self, params: ProductValue) -> Result> { + self.iter_by_col_eq(ST_VIEW_ARG_ID, StViewArgFields::Bytes, &AlgebraicValue::Product(params))? + .next() + .map(|row| Ok(ArgId(StViewArgRow::try_from(row)?.id))) + .transpose() + } + + /// Create parameters for a view, storing the values in `st_view_arg`. + pub fn create_or_get_params(&mut self, params: &ProductValue) -> Result { + if let Some(arg_id) = self.get_params(params.clone())? { + return Ok(arg_id); + } + let row = StViewArgRow { + id: ArgId::SENTINEL.0, + bytes: params + .to_bsatn_vec() + .map_err(|err| DatastoreError::ReadViaBsatnError(err.into()))? + .into(), + }; + let arg_id = self + .insert_via_serialize_bsatn(ST_VIEW_ARG_ID, &row)? + .1 + .collapse() + .read_col::(StViewArgFields::Id)?; + Ok(ArgId(arg_id)) + } + /// Create a table. /// /// Requires: diff --git a/crates/datastore/src/locking_tx_datastore/state_view.rs b/crates/datastore/src/locking_tx_datastore/state_view.rs index 3244a5ca92f..b2557eba643 100644 --- a/crates/datastore/src/locking_tx_datastore/state_view.rs +++ b/crates/datastore/src/locking_tx_datastore/state_view.rs @@ -5,8 +5,8 @@ use crate::locking_tx_datastore::mut_tx::{IndexScanPoint, IndexScanRanged}; use crate::system_tables::{ ConnectionIdViaU128, StColumnFields, StColumnRow, StConnectionCredentialsFields, StConnectionCredentialsRow, StConstraintFields, StConstraintRow, StIndexFields, StIndexRow, StScheduledFields, StScheduledRow, - StSequenceFields, StSequenceRow, StTableFields, StTableRow, StViewFields, StViewParamFields, StViewRow, - SystemTable, ST_COLUMN_ID, ST_CONNECTION_CREDENTIALS_ID, ST_CONSTRAINT_ID, ST_INDEX_ID, ST_SCHEDULED_ID, + StSequenceFields, StSequenceRow, StTableFields, StTableRow, StViewFields, StViewParamFields, StViewParamRow, + StViewRow, SystemTable, ST_COLUMN_ID, ST_CONNECTION_CREDENTIALS_ID, ST_CONSTRAINT_ID, ST_INDEX_ID, ST_SCHEDULED_ID, ST_SEQUENCE_ID, ST_TABLE_ID, ST_VIEW_ID, ST_VIEW_PARAM_ID, }; use anyhow::anyhow; @@ -14,6 +14,8 @@ use core::ops::RangeBounds; use spacetimedb_lib::ConnectionId; use spacetimedb_primitives::{ColList, TableId}; use spacetimedb_sats::AlgebraicValue; +use spacetimedb_schema::def::ViewParamDefSimple; +use spacetimedb_schema::identifier::Identifier; use spacetimedb_schema::schema::{ColumnSchema, TableSchema, ViewDefInfo}; use spacetimedb_table::table::IndexScanPointIter; use spacetimedb_table::{ @@ -133,14 +135,20 @@ pub trait StateView { .map(|mut iter| { iter.next().map(|row| -> Result<_> { let row = StViewRow::try_from(row)?; - let has_args = self + let args: Vec<_> = self .iter_by_col_eq(ST_VIEW_PARAM_ID, StViewParamFields::ViewId, &row.view_id.into())? - .next() - .is_some(); + .map(|param_row| { + let param_row = StViewParamRow::try_from(param_row)?; + Ok(ViewParamDefSimple { + name: Identifier::new(param_row.param_name).expect("valid identifier"), + ty: param_row.param_type.0, + }) + }) + .collect::>>()?; Ok(ViewDefInfo { view_id: row.view_id, - has_args, + args, is_anonymous: row.is_anonymous, }) }) diff --git a/crates/datastore/src/system_tables.rs b/crates/datastore/src/system_tables.rs index 86c343a8b33..9d8b5bda757 100644 --- a/crates/datastore/src/system_tables.rs +++ b/crates/datastore/src/system_tables.rs @@ -991,6 +991,13 @@ pub struct StViewParamRow { pub param_type: AlgebraicTypeViaBytes, } +impl TryFrom> for StViewParamRow { + type Error = DatastoreError; + fn try_from(row: RowRef<'_>) -> Result { + read_via_bsatn(row) + } +} + /// System table [ST_VIEW_SUB_NAME] /// /// | view_id | arg_id | identity | num_subscribers | has_subscribers | last_called | diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs index 6f7cb8949e0..14e4c59070b 100644 --- a/crates/expr/src/check.rs +++ b/crates/expr/src/check.rs @@ -1,25 +1,27 @@ +use std::cmp::max; use std::collections::HashMap; use std::ops::{Deref, DerefMut}; use std::sync::Arc; -use crate::expr::LeftDeepJoin; -use crate::expr::{Expr, ProjectList, ProjectName, Relvar}; +use super::{ + errors::{DuplicateName, TypingError, Unresolved, Unsupported}, + expr::RelExpr, + type_expr, type_proj, type_select, +}; +use crate::errors::{TableFunc, UnexpectedFunctionType}; +use crate::expr::{Expr, LeftDeepJoin, ProjectList, ProjectName, Relvar}; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::AlgebraicType; -use spacetimedb_primitives::TableId; +use spacetimedb_primitives::{ArgId, TableId}; +use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type; +use spacetimedb_sats::ProductValue; use spacetimedb_schema::schema::TableOrViewSchema; -use spacetimedb_sql_parser::ast::BinOp; +use spacetimedb_sql_parser::ast::{BinOp, SqlExpr, SqlLiteral}; use spacetimedb_sql_parser::{ ast::{sub::SqlSelect, SqlFrom, SqlIdent, SqlJoin}, parser::sub::parse_subscription, }; -use super::{ - errors::{DuplicateName, FunctionCall, TypingError, Unresolved, Unsupported}, - expr::RelExpr, - type_expr, type_proj, type_select, -}; - /// The result of type checking and name resolution pub type TypingResult = core::result::Result; @@ -32,6 +34,10 @@ pub trait SchemaView { fn schema(&self, name: &str) -> Option> { self.table_id(name).and_then(|table_id| self.schema_for_table(table_id)) } + + fn get_or_create_params(&self, params: &ProductValue) -> TypingResult { + Ok(ArgId::SENTINEL) + } } #[derive(Default)] @@ -58,10 +64,119 @@ pub trait TypeChecker { fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult; + fn type_view_params( + schema: &TableOrViewSchema, + vars: &mut Relvars, + args: Option>, + ) -> TypingResult> { + if !schema.is_view() && args.is_some() { + return Err(TypingError::from(TableFunc(schema.table_name.to_string()))); + } + if schema.view_info.as_ref().is_none_or(|a| a.args.is_empty()) && args.as_ref().is_none_or(|a| a.is_empty()) { + return Ok(None); + } + + let params_def: Vec<_> = schema + .view_info + .as_ref() + .map_or(Vec::new(), |info| info.args.clone()) + .into_iter() + .collect(); + + let args = args.unwrap_or_default(); + let len = max(params_def.len(), args.len()); + + let mut expected = Vec::with_capacity(params_def.len()); + let mut inferred = Vec::with_capacity(params_def.len()); + let mut params = Vec::with_capacity(params_def.len()); + let mut failed = false; + + let ty_literal = |lit: &SqlLiteral| match lit { + SqlLiteral::Bool(_) => fmt_algebraic_type(&AlgebraicType::Bool).to_string(), + SqlLiteral::Hex(_) => "Bytes?".to_string(), + SqlLiteral::Num(_) => "Num?".to_string(), + SqlLiteral::Str(_) => fmt_algebraic_type(&AlgebraicType::String).to_string(), + }; + for i in 0..len { + match (params_def.get(i), args.get(i)) { + (Some(param), Some(arg)) => match type_expr(vars, SqlExpr::Lit(arg.clone()), Some(¶m.ty)) { + Ok(Expr::Value(value, inferred_ty)) if inferred_ty == param.ty => { + if let Some(col) = schema.public_columns().get(i) { + if inferred_ty != col.col_type { + failed = true; + + inferred.push(fmt_algebraic_type(&col.col_type).to_string()); + expected.push(fmt_algebraic_type(&inferred_ty).to_string()); + + continue; + }; + expected.push(fmt_algebraic_type(¶m.ty).to_string()); + inferred.push(fmt_algebraic_type(&inferred_ty).to_string()); + } else { + failed = true; + expected.push("?".to_string()); + inferred.push(fmt_algebraic_type(&inferred_ty).to_string()); + continue; + }; + + params.push(value); + } + _ => { + failed = true; + expected.push(fmt_algebraic_type(¶m.ty).to_string()); + inferred.push(ty_literal(arg)); + } + }, + (Some(param), None) => { + failed = true; + expected.push(fmt_algebraic_type(¶m.ty).to_string()); + } + (None, Some(arg)) => { + failed = true; + inferred.push(ty_literal(arg)); + } + (None, None) => {} + } + } + + if failed { + return Err(UnexpectedFunctionType { + expected: expected.join(", "), + inferred: inferred.join(", "), + } + .into()); + } + let params = ProductValue::from_iter(params); + + Ok(Some(params)) + } + + fn type_params( + from: RelExpr, + schema: Arc, + alias: Box, + params: Option, + ) -> RelExpr { + match params { + None => from, + Some(args) => RelExpr::FunCall( + Relvar { + schema, + alias, + delta: None, + }, + args, + ), + } + } + fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult { match from { SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => { let schema = Self::type_relvar(tx, &name)?; + // Verify we don't have call `SELECT * FROM view` if the view requires parameters... + Self::type_view_params(&schema, vars, None)?; + vars.insert(alias.clone(), schema.clone()); Ok(RelExpr::RelVar(Relvar { schema, @@ -79,15 +194,18 @@ pub trait TypeChecker { }); for SqlJoin { from, on } in joins { - let (SqlIdent(name), SqlIdent(alias)) = from.into_name_alias(); + let (SqlIdent(name), SqlIdent(alias), params) = from.into_name_alias(); + assert!(params.is_none(), "Function calls not allowed in JOINs"); // Check for duplicate aliases if vars.contains_key(&alias) { return Err(DuplicateName(alias.into_string()).into()); } + let schema = Self::type_relvar(tx, &name)?; + let arg = Self::type_view_params(&schema, vars, params)?; + let lhs = Box::new(Self::type_params(join, schema.clone(), alias.clone(), arg)); - let lhs = Box::new(join); let rhs = Relvar { - schema: Self::type_relvar(tx, &name)?, + schema, alias, delta: None, }; @@ -109,8 +227,18 @@ pub trait TypeChecker { Ok(join) } - // TODO: support function calls in FROM clause - SqlFrom::FuncCall(_, _) => Err(FunctionCall.into()), + SqlFrom::FuncCall(func, SqlIdent(alias)) => { + let schema = Self::type_relvar(tx, &func.name.0)?; + let arg = Self::type_view_params(&schema, vars, Some(func.args))?; + vars.insert(alias.clone(), schema.clone()); + let from = RelExpr::RelVar(Relvar { + schema: schema.clone(), + alias: alias.clone(), + delta: None, + }); + + Ok(Self::type_params(from, schema, alias, arg)) + } } } @@ -176,6 +304,7 @@ fn expect_table_type(expr: ProjectList) -> TypingResult { pub mod test_utils { use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType}; use spacetimedb_primitives::TableId; + use spacetimedb_sats::AlgebraicType; use spacetimedb_schema::{ def::ModuleDef, schema::{Schema, TableOrViewSchema, TableSchema}, @@ -183,12 +312,24 @@ pub mod test_utils { use std::sync::Arc; use super::SchemaView; + pub struct ViewInfo<'a> { + pub(crate) name: &'a str, + pub(crate) columns: &'a [(&'a str, AlgebraicType)], + pub(crate) params: ProductType, + pub(crate) is_anonymous: bool, + } - pub fn build_module_def(types: Vec<(&str, ProductType)>) -> ModuleDef { + pub fn build_module_def(tables: Vec<(&str, ProductType)>, views: Vec) -> ModuleDef { let mut builder = RawModuleDefV9Builder::new(); - for (name, ty) in types { + for (name, ty) in tables { builder.build_table_with_new_type(name, ty, true); } + for view in views { + let product_type = AlgebraicType::from(ProductType::from_iter(view.columns.iter().cloned())); + let type_ref = builder.add_algebraic_type([], view.name, product_type, true); + let return_type = AlgebraicType::array(AlgebraicType::Ref(type_ref)); + builder.add_view(view.name, 0, true, view.is_anonymous, view.params, return_type); + } builder.finish().try_into().expect("failed to generate module def") } @@ -199,23 +340,31 @@ pub mod test_utils { match name { "t" => Some(TableId(0)), "s" => Some(TableId(1)), + "v" => Some(TableId(2)), + "w" => Some(TableId(3)), + "x" => Some(TableId(4)), _ => None, } } fn schema_for_table(&self, table_id: TableId) -> Option> { - match table_id.idx() { - 0 => Some((TableId(0), "t")), - 1 => Some((TableId(1), "s")), - _ => None, - } - .and_then(|(table_id, name)| { - self.0 - .table(name) - .map(|def| Arc::new(TableSchema::from_module_def(&self.0, def, (), table_id))) - .map(TableOrViewSchema::from) - .map(Arc::new) - }) + let (table_id, name) = match table_id.idx() { + 0 => (TableId(0), "t"), + 1 => (TableId(1), "s"), + 2 => (TableId(2), "v"), + 3 => (TableId(3), "w"), + 4 => (TableId(4), "x"), + _ => return None, + }; + self.0 + .table(name) + .map(|def| TableSchema::from_module_def(&self.0, def, (), table_id)) + .or_else(|| { + self.0 + .view(name) + .map(|def| TableSchema::from_view_def_for_datastore(&self.0, def)) + }) + .map(|x| Arc::new(TableOrViewSchema::from(Arc::new(x)))) } fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result>> { @@ -226,50 +375,51 @@ pub mod test_utils { #[cfg(test)] mod tests { + use super::{SchemaView, TypingResult}; use crate::{ check::test_utils::{build_module_def, SchemaViewer}, expr::ProjectName, }; use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType}; use spacetimedb_schema::def::ModuleDef; - - use super::{SchemaView, TypingResult}; - fn module_def() -> ModuleDef { - build_module_def(vec![ - ( - "t", - ProductType::from([ - ("ts", AlgebraicType::timestamp()), - ("i8", AlgebraicType::I8), - ("u8", AlgebraicType::U8), - ("i16", AlgebraicType::I16), - ("u16", AlgebraicType::U16), - ("i32", AlgebraicType::I32), - ("u32", AlgebraicType::U32), - ("i64", AlgebraicType::I64), - ("u64", AlgebraicType::U64), - ("int", AlgebraicType::U32), - ("f32", AlgebraicType::F32), - ("f64", AlgebraicType::F64), - ("i128", AlgebraicType::I128), - ("u128", AlgebraicType::U128), - ("i256", AlgebraicType::I256), - ("u256", AlgebraicType::U256), - ("str", AlgebraicType::String), - ("arr", AlgebraicType::array(AlgebraicType::String)), - ]), - ), - ( - "s", - ProductType::from([ - ("id", AlgebraicType::identity()), - ("u32", AlgebraicType::U32), - ("arr", AlgebraicType::array(AlgebraicType::String)), - ("bytes", AlgebraicType::bytes()), - ]), - ), - ]) + build_module_def( + vec![ + ( + "t", + ProductType::from([ + ("ts", AlgebraicType::timestamp()), + ("i8", AlgebraicType::I8), + ("u8", AlgebraicType::U8), + ("i16", AlgebraicType::I16), + ("u16", AlgebraicType::U16), + ("i32", AlgebraicType::I32), + ("u32", AlgebraicType::U32), + ("i64", AlgebraicType::I64), + ("u64", AlgebraicType::U64), + ("int", AlgebraicType::U32), + ("f32", AlgebraicType::F32), + ("f64", AlgebraicType::F64), + ("i128", AlgebraicType::I128), + ("u128", AlgebraicType::U128), + ("i256", AlgebraicType::I256), + ("u256", AlgebraicType::U256), + ("str", AlgebraicType::String), + ("arr", AlgebraicType::array(AlgebraicType::String)), + ]), + ), + ( + "s", + ProductType::from([ + ("id", AlgebraicType::identity()), + ("u32", AlgebraicType::U32), + ("arr", AlgebraicType::array(AlgebraicType::String)), + ("bytes", AlgebraicType::bytes()), + ]), + ), + ], + vec![], + ) } /// A wrapper around [super::parse_and_type_sub] that takes a dummy [AuthCtx] diff --git a/crates/expr/src/errors.rs b/crates/expr/src/errors.rs index 9b4894ab1ac..4a91980fde1 100644 --- a/crates/expr/src/errors.rs +++ b/crates/expr/src/errors.rs @@ -114,6 +114,13 @@ impl UnexpectedType { } } +#[derive(Debug, Error)] +#[error("Unexpected function type. Expected: ({expected}) != Inferred: ({inferred})")] +pub struct UnexpectedFunctionType { + pub expected: String, + pub inferred: String, +} + #[derive(Debug, Error)] #[error("Duplicate name `{0}`")] pub struct DuplicateName(pub String); @@ -129,8 +136,8 @@ pub struct DmlOnView { } #[derive(Debug, Error)] -#[error("Function calls are not supported")] -pub struct FunctionCall; +#[error("Table-valued functions are not supported: `{0}`")] +pub struct TableFunc(pub String); #[derive(Error, Debug)] pub enum TypingError { @@ -156,11 +163,15 @@ pub enum TypingError { #[error(transparent)] Unexpected(#[from] UnexpectedType), #[error(transparent)] + UnexpectedFunction(#[from] UnexpectedFunctionType), + #[error(transparent)] Wildcard(#[from] InvalidWildcard), #[error(transparent)] DuplicateName(#[from] DuplicateName), #[error(transparent)] FilterReturnType(#[from] FilterReturnType), #[error(transparent)] - FunctionCall(#[from] FunctionCall), + TableFunc(#[from] TableFunc), + #[error(transparent)] + Other(#[from] anyhow::Error), } diff --git a/crates/expr/src/expr.rs b/crates/expr/src/expr.rs index 451b9f3498c..f6556f34561 100644 --- a/crates/expr/src/expr.rs +++ b/crates/expr/src/expr.rs @@ -1,7 +1,9 @@ use std::{collections::HashSet, sync::Arc}; +use crate::errors::TypingError; use spacetimedb_lib::{query::Delta, AlgebraicType, AlgebraicValue}; -use spacetimedb_primitives::{TableId, ViewId}; +use spacetimedb_primitives::{ArgId, TableId, ViewId}; +use spacetimedb_sats::ProductValue; use spacetimedb_schema::schema::TableOrViewSchema; use spacetimedb_sql_parser::ast::{BinOp, LogOp}; @@ -98,6 +100,10 @@ impl ProjectName { } } +pub trait CallParams { + fn create_or_get_param(&mut self, param: &ProductValue) -> Result; +} + /// A projection is the root of any relational expression. /// This type represents a projection that returns fields. /// @@ -232,6 +238,35 @@ impl ProjectList { Self::Agg(_, _, name, ty) => f(name, ty), } } + + /// Iterate over the function calls in this projection list + pub fn for_each_fun_call( + &mut self, + f: &mut impl FnMut(ProductValue) -> Result, + ) -> Result<(), TypingError> { + match self { + ProjectList::Name(input) => { + for proj in input { + match proj { + ProjectName::None(expr) | ProjectName::Some(expr, _) => { + expr.for_each_fun_call(f)?; + } + } + } + } + ProjectList::List(input, _) => { + for expr in input { + expr.for_each_fun_call(f)?; + } + } + ProjectList::Limit(input, _) => { + input.for_each_fun_call(f)?; + } + ProjectList::Agg(_, _, _, _) => {} + } + + Ok(()) + } } /// A logical relational expression @@ -245,6 +280,8 @@ pub enum RelExpr { LeftDeepJoin(LeftDeepJoin), /// A left deep binary equi-join EqJoin(LeftDeepJoin, FieldProject, FieldProject), + /// A function call + FunCall(Relvar, ProductValue), } /// A table reference @@ -280,7 +317,7 @@ impl RelExpr { | Self::EqJoin(LeftDeepJoin { lhs, .. }, ..) => { lhs.visit(f); } - Self::RelVar(..) => {} + Self::RelVar(..) | Self::FunCall(..) => {} } } @@ -293,14 +330,14 @@ impl RelExpr { | Self::EqJoin(LeftDeepJoin { lhs, .. }, ..) => { lhs.visit_mut(f); } - Self::RelVar(..) => {} + Self::RelVar(..) | Self::FunCall(..) => {} } } /// The number of fields this expression returns pub fn nfields(&self) -> usize { match self { - Self::RelVar(..) => 1, + Self::RelVar(..) | Self::FunCall(..) => 1, Self::LeftDeepJoin(join) | Self::EqJoin(join, ..) => join.lhs.nfields() + 1, Self::Select(input, _) => input.nfields(), } @@ -310,6 +347,7 @@ impl RelExpr { pub fn has_field(&self, field: &str) -> bool { match self { Self::RelVar(Relvar { alias, .. }) => alias.as_ref() == field, + Self::FunCall(Relvar { alias, .. }, ..) => alias.as_ref() == field, Self::LeftDeepJoin(join) | Self::EqJoin(join, ..) => { join.rhs.alias.as_ref() == field || join.lhs.has_field(field) } @@ -360,6 +398,31 @@ impl RelExpr { _ => None, } } + + fn for_each_fun_call( + &mut self, + f: &mut impl FnMut(ProductValue) -> Result, + ) -> Result<(), TypingError> { + // For function calls, we need to filter by the argument id + if let RelExpr::FunCall(relvar, param) = self { + let new_arg_id = f(param.clone())?; + let arg_id_col = relvar.schema.inner().get_column_by_name("arg_id").unwrap().col_pos; + + *self = RelExpr::Select( + Box::new(RelExpr::RelVar(relvar.clone())), + Expr::BinOp( + BinOp::Eq, + Box::new(Expr::Field(FieldProject { + table: relvar.alias.clone(), + field: arg_id_col.idx(), + ty: AlgebraicType::U64, + })), + Box::new(Expr::Value(AlgebraicValue::U64(new_arg_id.0), AlgebraicType::U64)), + ), + ); + } + Ok(()) + } } /// A left deep binary cross product diff --git a/crates/expr/src/lib.rs b/crates/expr/src/lib.rs index 2d9b3cdc5ab..0a322be705f 100644 --- a/crates/expr/src/lib.rs +++ b/crates/expr/src/lib.rs @@ -93,7 +93,22 @@ fn _type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>, d (SqlExpr::Lit(SqlLiteral::Str(_) | SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => { Err(Unresolved::Literal.into()) } - (SqlExpr::Lit(SqlLiteral::Str(v) | SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(ty)) => Ok(Expr::Value( + (SqlExpr::Lit(SqlLiteral::Num(v)), Some(ty)) => { + if ty.is_integer() || ty.is_float() || ty.is_special() { + Ok(Expr::Value( + parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?, + ty.clone(), + )) + } else { + let expected = if v.contains(".") || v.contains("e") || v.contains("E") { + AlgebraicType::F64 + } else { + AlgebraicType::I64 + }; + Err(UnexpectedType::new(&expected, ty).into()) + } + } + (SqlExpr::Lit(SqlLiteral::Str(v) | SqlLiteral::Hex(v)), Some(ty)) => Ok(Expr::Value( parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?, ty.clone(), )), diff --git a/crates/expr/src/rls.rs b/crates/expr/src/rls.rs index 75f096ae50c..89cd9e5a2a7 100644 --- a/crates/expr/src/rls.rs +++ b/crates/expr/src/rls.rs @@ -362,7 +362,7 @@ fn alpha_rename(expr: &mut RelExpr, f: &mut impl FnMut(&str) -> Box) { field.table = f(&field.table); } expr.visit_mut(&mut |expr| match expr { - RelExpr::RelVar(rhs) | RelExpr::LeftDeepJoin(LeftDeepJoin { rhs, .. }) => { + RelExpr::RelVar(rhs) | RelExpr::FunCall(rhs, ..) | RelExpr::LeftDeepJoin(LeftDeepJoin { rhs, .. }) => { rename(rhs, f); } RelExpr::EqJoin(LeftDeepJoin { rhs, .. }, a, b) => { @@ -418,7 +418,7 @@ fn alpha_rename(expr: &mut RelExpr, f: &mut impl FnMut(&str) -> Box) { /// ``` fn extend_lhs(expr: RelExpr, with: RelExpr) -> RelExpr { match expr { - RelExpr::RelVar(rhs) => RelExpr::LeftDeepJoin(LeftDeepJoin { + RelExpr::RelVar(rhs) | RelExpr::FunCall(rhs, ..) => RelExpr::LeftDeepJoin(LeftDeepJoin { lhs: Box::new(with), rhs, }), @@ -443,8 +443,8 @@ fn extend_lhs(expr: RelExpr, with: RelExpr) -> RelExpr { fn expand_leaf(expr: RelExpr, table_id: TableId, alias: &str, with: &RelExpr) -> RelExpr { let ok = |relvar: &Relvar| relvar.schema.table_id == table_id && relvar.alias.as_ref() == alias; match expr { - RelExpr::RelVar(relvar, ..) if ok(&relvar) => with.clone(), - RelExpr::RelVar(..) => expr, + RelExpr::RelVar(relvar, ..) | RelExpr::FunCall(relvar, ..) if ok(&relvar) => with.clone(), + RelExpr::RelVar(..) | RelExpr::FunCall(..) => expr, RelExpr::Select(input, expr) => RelExpr::Select(Box::new(expand_leaf(*input, table_id, alias, with)), expr), RelExpr::LeftDeepJoin(join) if ok(&join.rhs) => extend_lhs(with.clone(), *join.lhs), RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs }) => RelExpr::LeftDeepJoin(LeftDeepJoin { @@ -529,20 +529,23 @@ mod tests { } fn module_def() -> ModuleDef { - build_module_def(vec![ - ( - "users", - ProductType::from([("identity", AlgebraicType::identity()), ("id", AlgebraicType::U64)]), - ), - ( - "admins", - ProductType::from([("identity", AlgebraicType::identity()), ("id", AlgebraicType::U64)]), - ), - ( - "player", - ProductType::from([("id", AlgebraicType::U64), ("level_num", AlgebraicType::U64)]), - ), - ]) + build_module_def( + vec![ + ( + "users", + ProductType::from([("identity", AlgebraicType::identity()), ("id", AlgebraicType::U64)]), + ), + ( + "admins", + ProductType::from([("identity", AlgebraicType::identity()), ("id", AlgebraicType::U64)]), + ), + ( + "player", + ProductType::from([("id", AlgebraicType::U64), ("level_num", AlgebraicType::U64)]), + ), + ], + vec![], + ) } /// Parse, type check, and resolve RLS rules diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index 46db2b02053..1fc882cf5ca 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -462,43 +462,61 @@ pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView, auth: &AuthCtx) #[cfg(test)] mod tests { - use std::sync::Arc; - use super::Statement; use crate::ast::LogOp; + use crate::check::test_utils::ViewInfo; use crate::check::{ test_utils::{build_module_def, SchemaViewer}, Relvars, SchemaView, TypingResult, }; use crate::type_expr; - use spacetimedb::TableId; - use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9Builder; use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType}; use spacetimedb_schema::def::ModuleDef; - use spacetimedb_schema::schema::{TableOrViewSchema, TableSchema}; use spacetimedb_sql_parser::ast::{SqlExpr, SqlLiteral}; fn module_def() -> ModuleDef { - build_module_def(vec![ - ( - "t", - ProductType::from([ - ("u32", AlgebraicType::U32), - ("f32", AlgebraicType::F32), - ("str", AlgebraicType::String), - ("arr", AlgebraicType::array(AlgebraicType::String)), - ]), - ), - ( - "s", - ProductType::from([ - ("id", AlgebraicType::identity()), - ("u32", AlgebraicType::U32), - ("arr", AlgebraicType::array(AlgebraicType::String)), - ("bytes", AlgebraicType::bytes()), - ]), - ), - ]) + build_module_def( + vec![ + ( + "t", + ProductType::from([ + ("u32", AlgebraicType::U32), + ("f32", AlgebraicType::F32), + ("str", AlgebraicType::String), + ("arr", AlgebraicType::array(AlgebraicType::String)), + ]), + ), + ( + "s", + ProductType::from([ + ("id", AlgebraicType::identity()), + ("u32", AlgebraicType::U32), + ("arr", AlgebraicType::array(AlgebraicType::String)), + ("bytes", AlgebraicType::bytes()), + ]), + ), + ], + vec![ + ViewInfo { + name: "v", + is_anonymous: true, + columns: &[("a", AlgebraicType::String)], + params: ProductType::unit(), + }, + ViewInfo { + name: "w", + is_anonymous: false, + columns: &[("a", AlgebraicType::U64)], + params: ProductType::from([("x", AlgebraicType::U64)]), + }, + ViewInfo { + name: "x", + is_anonymous: false, + columns: &[("a", AlgebraicType::U32), ("b", AlgebraicType::String)], + params: ProductType::from([("p1", AlgebraicType::U32), ("p2", AlgebraicType::String)]), + }, + ], + ) } /// A wrapper around [super::parse_and_type_sql] that takes a dummy [AuthCtx] @@ -563,56 +581,7 @@ mod tests { #[test] fn views() { - struct SchemaViewer { - module_def: ModuleDef, - } - - impl SchemaViewer { - fn schema_for_view(&self, name: &str) -> Option> { - self.module_def - .view(name) - .map(|def| TableSchema::from_view_def_for_datastore(&self.module_def, def)) - .map(Arc::new) - .map(TableOrViewSchema::from) - .map(Arc::new) - } - } - - impl SchemaView for SchemaViewer { - fn table_id(&self, _: &str) -> Option { - None - } - - fn schema(&self, name: &str) -> Option> { - self.schema_for_view(name) - } - - fn schema_for_table(&self, _: TableId) -> Option> { - self.schema_for_view("v") - } - - fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result>> { - Ok(vec![]) - } - } - - fn build_view_def( - builder: &mut RawModuleDefV9Builder, - name: &str, - columns: impl Into, - is_anonymous: bool, - ) { - let product_type = AlgebraicType::from(columns.into()); - let type_ref = builder.add_algebraic_type([], name, product_type, true); - let return_type = AlgebraicType::array(AlgebraicType::Ref(type_ref)); - builder.add_view(name, 0, true, is_anonymous, ProductType::unit(), return_type); - } - - let mut builder = RawModuleDefV9Builder::new(); - build_view_def(&mut builder, "v", [("a", AlgebraicType::String)], true); - let module_def: ModuleDef = builder.finish().try_into().expect("failed to generate module def"); - - let tx = SchemaViewer { module_def }; + let tx = SchemaViewer(module_def()); struct TestCase { sql: &'static str, @@ -628,9 +597,19 @@ mod tests { sql: "select * from v where a = 'hello'", msg: "Column selection on view", }, + TestCase { + sql: "select * from v", + msg: "Function call returning view", + }, + TestCase { + sql: "select * from w(1)", + msg: "Function call returning view with parameters", + }, ] { - let result = parse_and_type_sql(sql, &tx); - assert!(result.is_ok(), "{msg}"); + let result = parse_and_type_sql(sql, &tx).inspect_err(|e| { + panic!("Expected OK for `{sql}` but got error: {e}"); + }); + assert!(result.is_ok(), "{msg}: {sql}"); } for TestCase { sql, msg } in [ @@ -646,9 +625,59 @@ mod tests { sql: "select arg_id from v", msg: "`v` does not have a column named `arg_id`", }, + TestCase { + sql: "select * from v(1)", + msg: "`v` does not take parameters", + }, ] { let result = parse_and_type_sql(sql, &tx); assert!(result.is_err(), "{msg}"); } } + + #[test] + fn params_validation() { + let tx = SchemaViewer(module_def()); + + struct TestCase { + sql: &'static str, + msg: &'static str, + } + + for TestCase { sql, msg } in [ + TestCase { + sql: "select * from x(1, 'hello')", + msg: "Correct parameters", + }, + TestCase { + sql: "select * from x()", + msg: "Unsupported call to table-valued function with empty params. Use `select * from table_function` syntax instead: x", + }, + TestCase { + sql: "select * from x", + msg: "Unexpected function type. Expected: (U32, String) != Inferred: ()", + }, + TestCase { + sql: "select * from x('hello', 1)", + msg: "Unexpected function type. Expected: (U32, String) != Inferred: (String, Num?)", + }, + TestCase { + sql: "select * from x(1)", + msg: "Unexpected function type. Expected: (U32, String) != Inferred: (U32)", + }, + TestCase { + sql: "select * from x(1, 'hello', 2)", + msg: "Unexpected function type. Expected: (U32, String) != Inferred: (U32, String, Num?)", + }, + ] { + let result = parse_and_type_sql(sql, &tx); + if msg == "Correct parameters" { + assert!(result.is_ok(), "{msg}: {sql}"); + } else if let Err(err) = &result { + assert_eq!(err.to_string(), msg, "{sql}"); + } else { + panic!("Expected error for SQL `{sql}` but got OK"); + } + } + } } diff --git a/crates/physical-plan/src/compile.rs b/crates/physical-plan/src/compile.rs index 0997fa0b968..c7737e7d1f7 100644 --- a/crates/physical-plan/src/compile.rs +++ b/crates/physical-plan/src/compile.rs @@ -63,7 +63,7 @@ fn compile_field_project(var: &mut impl VarLabel, expr: FieldProject) -> TupleFi fn compile_rel_expr(var: &mut impl VarLabel, ast: RelExpr) -> PhysicalPlan { match ast { - RelExpr::RelVar(Relvar { schema, alias, delta }) => { + RelExpr::RelVar(Relvar { schema, alias, delta }) | RelExpr::FunCall(Relvar { schema, alias, delta }, ..) => { let label = var.label(alias.as_ref()); let schema = schema.inner(); PhysicalPlan::TableScan( diff --git a/crates/physical-plan/src/plan.rs b/crates/physical-plan/src/plan.rs index 81933829ef9..597ab6dd6eb 100644 --- a/crates/physical-plan/src/plan.rs +++ b/crates/physical-plan/src/plan.rs @@ -1412,7 +1412,10 @@ mod tests { identity::AuthCtx, AlgebraicType, AlgebraicValue, }; - use spacetimedb_primitives::{ColId, ColList, ColSet, TableId}; + use spacetimedb_primitives::{ColId, ColList, ColSet, TableId, ViewId}; + use spacetimedb_schema::def::ViewParamDefSimple; + use spacetimedb_schema::identifier::Identifier; + use spacetimedb_schema::schema::ViewDefInfo; use spacetimedb_schema::{ def::{BTreeAlgorithm, ConstraintData, IndexAlgorithm, UniqueConstraintData}, schema::{ColumnSchema, ConstraintSchema, IndexSchema, TableOrViewSchema, TableSchema}, @@ -1447,18 +1450,36 @@ mod tests { } } - fn schema( + fn schema_with_params( table_id: TableId, table_name: &str, columns: &[(&str, AlgebraicType)], indexes: &[&[usize]], unique: &[&[usize]], primary_key: Option, + params: Option<&[(&str, AlgebraicType)]>, ) -> TableOrViewSchema { + let mut columns = columns.to_vec(); + // Need to add the hidden columns for views, see `TableSchema::from_view_def_for_datastore` + if params.is_some() { + columns.insert(0, ("arg_id", AlgebraicType::U64)); + } + let view = params.map(|params| ViewDefInfo { + view_id: ViewId::SENTINEL, + args: params + .iter() + .map(|(name, ty)| ViewParamDefSimple { + name: Identifier::new((*name).into()).unwrap(), + ty: ty.clone(), + }) + .collect(), + is_anonymous: true, + }); + TableOrViewSchema::from(Arc::new(TableSchema::new( table_id, table_name.to_owned().into_boxed_str(), - None, + view, columns .iter() .enumerate() @@ -1501,6 +1522,17 @@ mod tests { ))) } + fn schema( + table_id: TableId, + table_name: &str, + columns: &[(&str, AlgebraicType)], + indexes: &[&[usize]], + unique: &[&[usize]], + primary_key: Option, + ) -> TableOrViewSchema { + schema_with_params(table_id, table_name, columns, indexes, unique, primary_key, None) + } + /// A wrapper around [spacetimedb_expr::check::parse_and_type_sub] that takes a dummy [AuthCtx] fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult { spacetimedb_expr::check::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan) @@ -2265,4 +2297,67 @@ mod tests { assert!(plan.plan_iter().any(|plan| plan.has_filter())); assert!(plan.plan_iter().any(|plan| plan.has_table_scan(None))); } + + // Verify the view parameters are converted to filters in the physical plan + #[test] + fn view_params() { + let t_id = TableId(1); + let v_id = TableId(2); + + let t = Arc::new(schema( + t_id, + "t", + &[("id", AlgebraicType::U64), ("x", AlgebraicType::U32)], + &[&[0]], + &[&[0]], + Some(0), + )); + let v = Arc::new(schema_with_params( + v_id, + "v", + &[("id", AlgebraicType::U64), ("x", AlgebraicType::U32)], + &[], + &[], + None, + Some(&[("param_id", AlgebraicType::U64)]), + )); + + let db = SchemaViewer { + schemas: vec![t.clone(), v.clone()], + }; + + let sql = "select * from v(0)"; + + let auth = AuthCtx::for_testing(); + let lp = parse_and_type_sub(sql, &db).unwrap(); + dbg!(&lp); + let pp = compile_select(lp).optimize(&auth).unwrap(); + dbg!(&pp); + match pp { + ProjectPlan::None(PhysicalPlan::Filter(input, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => { + assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 1, .. }))); + assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U64(0)))); + + match *input { + PhysicalPlan::TableScan(TableScan { schema, .. }, _) => { + assert_eq!(schema.table_id, v_id); + } + plan => panic!("unexpected plan: {plan:#?}"), + } + } + proj => panic!("unexpected project: {proj:#?}"), + }; + + let sql = "select * from v(0) as x JOIN t ON x.id = t.id"; + let lp = parse_and_type_sub(sql, &db).unwrap(); + let pp = compile_select(lp).optimize(&auth).unwrap(); + + match pp { + ProjectPlan::None(PhysicalPlan::Filter(_, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => { + assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 1, .. }))); + assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U64(0)))); + } + proj => panic!("unexpected project: {proj:#?}"), + }; + } } diff --git a/crates/query/src/lib.rs b/crates/query/src/lib.rs index d75f2516a2c..96efff60ddb 100644 --- a/crates/query/src/lib.rs +++ b/crates/query/src/lib.rs @@ -4,6 +4,8 @@ use spacetimedb_execution::{ pipelined::ProjectListExecutor, Datastore, DeltaStore, }; +use spacetimedb_expr::errors::TypingError; +use spacetimedb_expr::expr::CallParams; use spacetimedb_expr::{ check::{parse_and_type_sub, SchemaView}, expr::ProjectList, @@ -15,13 +17,27 @@ use spacetimedb_physical_plan::{ compile::{compile_dml_plan, compile_select, compile_select_list}, plan::{ProjectListPlan, ProjectPlan}, }; -use spacetimedb_primitives::TableId; +use spacetimedb_primitives::{ArgId, TableId}; +use std::collections::HashMap; /// DIRTY HACK ALERT: Maximum allowed length, in UTF-8 bytes, of SQL queries. /// Any query longer than this will be rejected. /// This prevents a stack overflow when compiling queries with deeply-nested `AND` and `OR` conditions. const MAX_SQL_LENGTH: usize = 50_000; +pub trait CallParamsExt { + fn get_arg(&self, params: &ProductValue) -> Result; +} + +pub struct MockCallParams { + params: HashMap, +} +impl CallParamsExt for MockCallParams { + fn get_arg(&self, params: &ProductValue) -> Result { + todo!() + } +} + pub fn compile_subscription( sql: &str, tx: &impl SchemaView, diff --git a/crates/schema/src/def.rs b/crates/schema/src/def.rs index 2c0a97d9ef4..c30a0c31aec 100644 --- a/crates/schema/src/def.rs +++ b/crates/schema/src/def.rs @@ -833,6 +833,16 @@ pub struct ViewParamDef { pub view_name: Identifier, } +/// A struct representing a validated view parameter +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct ViewParamDefSimple { + /// The name of the parameter. + pub name: Identifier, + + /// The type of this parameter. + pub ty: AlgebraicType, +} + /// A constraint definition attached to a table. #[derive(Debug, Clone, Eq, PartialEq)] pub struct ConstraintDef { diff --git a/crates/schema/src/schema.rs b/crates/schema/src/schema.rs index 51826292454..fca8e634e22 100644 --- a/crates/schema/src/schema.rs +++ b/crates/schema/src/schema.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::def::{ ColumnDef, ConstraintData, ConstraintDef, IndexAlgorithm, IndexDef, ModuleDef, ModuleDefLookup, ScheduleDef, - SequenceDef, TableDef, UniqueConstraintData, ViewColumnDef, ViewDef, + SequenceDef, TableDef, UniqueConstraintData, ViewColumnDef, ViewDef, ViewParamDefSimple, }; use crate::identifier::Identifier; @@ -50,16 +50,16 @@ pub trait Schema: Sized { fn check_compatible(&self, module_def: &ModuleDef, def: &Self::Def) -> Result<(), anyhow::Error>; } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ViewDefInfo { pub view_id: ViewId, - pub has_args: bool, + pub args: Vec, pub is_anonymous: bool, } impl ViewDefInfo { pub fn num_private_cols(&self) -> usize { - (if self.is_anonymous { 0 } else { 1 }) + (if self.has_args { 1 } else { 0 }) + (if self.is_anonymous { 0 } else { 1 }) + (if !self.args.is_empty() { 1 } else { 0 }) } } @@ -77,7 +77,7 @@ impl From> for TableOrViewSchema { fn from(inner: Arc) -> Self { Self { table_id: inner.table_id, - view_info: inner.view_info, + view_info: inner.view_info.clone(), table_name: inner.table_name.clone(), table_access: inner.table_access, inner, @@ -110,31 +110,22 @@ impl TableOrViewSchema { /// For views in particular it will not include the internal `sender` and `arg_id` columns. /// Hence columns in this list should be looked up by their [`ColId`] - not their position. pub fn public_columns(&self) -> &[ColumnSchema] { - match self.view_info { - Some(ViewDefInfo { - has_args: true, - is_anonymous: false, - .. - }) => &self.inner.columns[2..], - Some(ViewDefInfo { - has_args: true, - is_anonymous: true, - .. - }) => &self.inner.columns[1..], - Some(ViewDefInfo { - has_args: false, - is_anonymous: false, - .. - }) => &self.inner.columns[1..], - Some(ViewDefInfo { - has_args: false, - is_anonymous: true, - .. - }) - | None => &self.inner.columns, - } - } + let skip = match &self.view_info { + Some(info) => { + let has_args = !info.args.is_empty(); + + match (has_args, info.is_anonymous) { + (true, false) => 2, // sender + arg_id + (true, true) => 1, // arg_id + (false, false) => 1, // sender + (false, true) => 0, // none + } + } + None => 0, + }; + &self.inner.columns[skip..] + } /// Check if the `col_name` exist on this [`TableOrViewSchema`] pub fn get_column_by_name(&self, col_name: &str) -> Option<&ColumnSchema> { self.public_columns().iter().find(|x| &*x.col_name == col_name) @@ -742,7 +733,13 @@ impl TableSchema { let view_info = ViewDefInfo { view_id: ViewId::SENTINEL, - has_args: !param_columns.is_empty(), + args: param_columns + .iter() + .map(|arg| ViewParamDefSimple { + name: arg.name.clone(), + ty: arg.ty.clone(), + }) + .collect(), is_anonymous: *is_anonymous, }; @@ -864,7 +861,13 @@ impl TableSchema { let view_info = ViewDefInfo { view_id: ViewId::SENTINEL, - has_args: !param_columns.is_empty(), + args: param_columns + .iter() + .map(|arg| ViewParamDefSimple { + name: arg.name.clone(), + ty: arg.ty.clone(), + }) + .collect(), is_anonymous: *is_anonymous, }; diff --git a/crates/sql-parser/src/ast/mod.rs b/crates/sql-parser/src/ast/mod.rs index 671c6c342b8..a6ffa327f80 100644 --- a/crates/sql-parser/src/ast/mod.rs +++ b/crates/sql-parser/src/ast/mod.rs @@ -31,10 +31,10 @@ pub enum SqlFromSource { } impl SqlFromSource { - pub fn into_name_alias(self) -> (SqlIdent, SqlIdent) { + pub fn into_name_alias(self) -> (SqlIdent, SqlIdent, Option>) { match self { - Self::Expr(name, alias) => (name, alias), - Self::FuncCall(func, alias) => (func.name, alias), + Self::Expr(name, alias) => (name, alias, None), + Self::FuncCall(func, alias) => (func.name, alias, Some(func.args)), } } } @@ -213,7 +213,7 @@ impl From for SqlIdent { } /// A SQL constant expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum SqlLiteral { /// A boolean constant Bool(bool), diff --git a/modules/module-test/src/lib.rs b/modules/module-test/src/lib.rs index a5b59720cfd..c6b25aad5cf 100644 --- a/modules/module-test/src/lib.rs +++ b/modules/module-test/src/lib.rs @@ -209,6 +209,11 @@ fn my_player(ctx: &ViewContext) -> Option { ctx.db.player().identity().find(ctx.sender) } +/*#[spacetimedb::view(name = players_in_chunk, public)] +fn players_in_chunk(ctx: &AnonymousViewContext, chunk_index: u32) -> Query { + ctx.db.player().chunk_index().filter(chunk_index) +}*/ + // ───────────────────────────────────────────────────────────────────────────── // REDUCERS // ───────────────────────────────────────────────────────────────────────────── From 4289cd8517282e0786dcc4768928111dfd5718bc Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Thu, 25 Dec 2025 16:43:09 -0500 Subject: [PATCH 3/3] Making tx mut so we can auto-create params on compilation --- crates/bench/benches/subscription.rs | 24 ++-- crates/core/src/estimation.rs | 10 +- crates/core/src/host/module_host.rs | 6 +- .../src/host/wasm_common/module_host_actor.rs | 4 +- crates/core/src/sql/ast.rs | 23 +++- crates/core/src/sql/compiler.rs | 82 ++++++------ crates/core/src/sql/execute.rs | 57 ++++++--- crates/core/src/sql/parser.rs | 2 +- .../subscription/module_subscription_actor.rs | 14 +- .../module_subscription_manager.rs | 4 +- crates/core/src/subscription/query.rs | 77 +++++++---- crates/core/src/subscription/subscription.rs | 18 +-- crates/core/src/util/slow.rs | 12 +- crates/expr/src/check.rs | 120 ++++++++++++------ crates/expr/src/errors.rs | 2 + crates/expr/src/expr.rs | 54 -------- crates/expr/src/rls.rs | 69 +++++----- crates/expr/src/statement.rs | 28 ++-- crates/physical-plan/src/plan.rs | 72 ++++++----- crates/query/src/lib.rs | 23 +--- crates/subscription/src/lib.rs | 2 +- 21 files changed, 379 insertions(+), 324 deletions(-) diff --git a/crates/bench/benches/subscription.rs b/crates/bench/benches/subscription.rs index 065916f78b2..9f291f0fb92 100644 --- a/crates/bench/benches/subscription.rs +++ b/crates/bench/benches/subscription.rs @@ -126,10 +126,10 @@ fn eval(c: &mut Criterion) { // A benchmark runner for the new query engine let bench_query = |c: &mut Criterion, name, sql| { c.bench_function(name, |b| { - let tx = raw.db.begin_tx(Workload::Subscribe); + let mut tx = raw.db.begin_tx(Workload::Subscribe); let auth = AuthCtx::for_testing(); - let schema_viewer = &SchemaViewer::new(&tx, &auth); - let (plans, table_id, table_name, _) = compile_subscription(sql, schema_viewer, &auth).unwrap(); + let mut schema_viewer = SchemaViewer::new(&mut tx, &auth); + let (plans, table_id, table_name, _) = compile_subscription(sql, &mut schema_viewer, &auth).unwrap(); let plans = plans .into_iter() .map(|plan| plan.optimize(&auth).unwrap()) @@ -155,8 +155,8 @@ fn eval(c: &mut Criterion) { let bench_eval = |c: &mut Criterion, name, sql| { c.bench_function(name, |b| { - let tx = raw.db.begin_tx(Workload::Update); - let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &tx, sql).unwrap(); + let mut tx = raw.db.begin_tx(Workload::Update); + let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, sql).unwrap(); let query: ExecutionSet = query.into(); b.iter(|| { @@ -207,11 +207,11 @@ fn eval(c: &mut Criterion) { // A passthru executed independently of the database. let select_lhs = "select * from footprint"; let select_rhs = "select * from location"; - let tx = &raw.db.begin_tx(Workload::Update); - let query_lhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, select_lhs).unwrap(); - let query_rhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, select_rhs).unwrap(); + let mut tx = raw.db.begin_tx(Workload::Update); + let query_lhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, select_lhs).unwrap(); + let query_rhs = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, select_rhs).unwrap(); let query = ExecutionSet::from_iter(query_lhs.into_iter().chain(query_rhs)); - let tx = &tx.into(); + let tx = &(&mut tx).into(); b.iter(|| drop(black_box(query.eval_incr_for_test(&raw.db, tx, &update, None)))) }); @@ -226,10 +226,10 @@ fn eval(c: &mut Criterion) { from footprint join location on footprint.entity_id = location.entity_id \ where location.chunk_index = {chunk_index}" ); - let tx = &raw.db.begin_tx(Workload::Update); - let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), tx, &join).unwrap(); + let mut tx = raw.db.begin_tx(Workload::Update); + let query = compile_read_only_queryset(&raw.db, &AuthCtx::for_testing(), &mut tx, &join).unwrap(); let query: ExecutionSet = query.into(); - let tx = &tx.into(); + let tx = &(&mut tx).into(); b.iter(|| drop(black_box(query.eval_incr_for_test(&raw.db, tx, &update, None)))); }); diff --git a/crates/core/src/estimation.rs b/crates/core/src/estimation.rs index 24130a145f8..5e046b06c38 100644 --- a/crates/core/src/estimation.rs +++ b/crates/core/src/estimation.rs @@ -181,8 +181,8 @@ mod tests { } fn num_rows_for(db: &RelationalDB, sql: &str) -> u64 { - let tx = begin_tx(db); - match &*compile_sql(db, &AuthCtx::for_testing(), &tx, sql).expect("Failed to compile sql") { + let mut tx = begin_tx(db); + match &*compile_sql(db, &AuthCtx::for_testing(), &mut tx, sql).expect("Failed to compile sql") { [CrudExpr::Query(expr)] => num_rows(&tx, expr), exprs => panic!("unexpected result from compilation: {exprs:#?}"), } @@ -191,10 +191,10 @@ mod tests { /// Using the new query plan fn new_row_estimate(db: &RelationalDB, sql: &str) -> u64 { let auth = AuthCtx::for_testing(); - let tx = begin_tx(db); - let tx = SchemaViewer::new(&tx, &auth); + let mut tx = begin_tx(db); + let mut tx = SchemaViewer::new(&mut tx, &auth); - compile_subscription(sql, &tx, &auth) + compile_subscription(sql, &mut tx, &auth) .map(|(plans, ..)| plans) .expect("failed to compile sql query") .into_iter() diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index aaafeb337f1..3f77ba4e6fe 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -1869,7 +1869,7 @@ impl ModuleHost { let metrics = self .on_module_thread("one_off_query", move || { let (tx_offset_sender, tx_offset_receiver) = oneshot::channel(); - let tx = scopeguard::guard(db.begin_tx(Workload::Sql), |tx| { + let mut tx = scopeguard::guard(db.begin_tx(Workload::Sql), |tx| { let (tx_offset, tx_metrics, reducer) = db.release_tx(tx); let _ = tx_offset_sender.send(tx_offset); db.report_read_tx_metrics(reducer, tx_metrics); @@ -1878,7 +1878,7 @@ impl ModuleHost { // We wrap the actual query in a closure so we can use ? to handle errors without making // the entire transaction abort with an error. let result: Result<(OneOffTable, ExecutionMetrics), anyhow::Error> = (|| { - let tx = SchemaViewer::new(&*tx, &auth); + let mut tx = SchemaViewer::new(&mut *tx, &auth); let ( // A query may compile down to several plans. @@ -1888,7 +1888,7 @@ impl ModuleHost { _, table_name, _, - ) = compile_subscription(&query, &tx, &auth)?; + ) = compile_subscription(&query, &mut tx, &auth)?; // Optimize each fragment let optimized = plans diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index 13cdd04fd1b..543c07c5a6d 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -1126,10 +1126,10 @@ impl InstanceCommon { // Views bypass RLS, since views should enforce their own access control procedurally. let auth = AuthCtx::for_current(self.info.database_identity); - let schema_view = SchemaViewer::new(&*tx, &auth); + let mut schema_view = SchemaViewer::new(&mut *tx, &auth); // Compile to subscription plans. - let (plans, has_params) = SubscriptionPlan::compile(the_query, &schema_view, &auth)?; + let (plans, has_params) = SubscriptionPlan::compile(the_query, &mut schema_view, &auth)?; ensure!( !has_params, "parameterized SQL is not supported for view materialization yet" diff --git a/crates/core/src/sql/ast.rs b/crates/core/src/sql/ast.rs index a11775331ce..51f73c8f4f8 100644 --- a/crates/core/src/sql/ast.rs +++ b/crates/core/src/sql/ast.rs @@ -5,6 +5,7 @@ use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; use spacetimedb_datastore::system_tables::{StRowLevelSecurityFields, ST_ROW_LEVEL_SECURITY_ID}; use spacetimedb_expr::check::{SchemaView, TypingResult}; +use spacetimedb_expr::errors::TypingError; use spacetimedb_expr::statement::compile_sql_stmt; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_primitives::{ArgId, ColId, TableId}; @@ -22,7 +23,7 @@ use sqlparser::ast::{ }; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; /// Simplify to detect features of the syntax we don't support yet @@ -477,7 +478,7 @@ fn compile_where(table: &From, filter: Option) -> Result { - pub(crate) tx: &'a T, + tx: &'a mut T, auth: &'a AuthCtx, } @@ -489,6 +490,12 @@ impl Deref for SchemaViewer<'_, T> { } } +impl DerefMut for SchemaViewer<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.tx + } +} + impl SchemaView for SchemaViewer<'_, T> { fn table_id(&self, name: &str) -> Option { // Get the schema from the in-memory state instead of fetching from the database for speed @@ -536,10 +543,15 @@ impl SchemaView for SchemaViewer<'_, T> { }) .collect::>() } + + fn get_or_create_params(&mut self, _params: ProductValue) -> TypingResult { + // Caller should have used `SchemaViewerMut` on crate `core` + Err(TypingError::ParamsReadOnly) + } } impl<'a, T> SchemaViewer<'a, T> { - pub fn new(tx: &'a T, auth: &'a AuthCtx) -> Self { + pub fn new(tx: &'a mut T, auth: &'a AuthCtx) -> Self { Self { tx, auth } } } @@ -1000,13 +1012,12 @@ fn compile_statement( pub(crate) fn compile_to_ast( db: &RelationalDB, auth: &AuthCtx, - tx: &T, + tx: &mut T, sql_text: &str, ) -> Result, DBError> { // NOTE: The following ensures compliance with the 1.0 sql api. // Come 1.0, it will have replaced the current compilation stack. - compile_sql_stmt(sql_text, &SchemaViewer::new(tx, auth), auth)?; - + compile_sql_stmt(sql_text, &mut SchemaViewer::new(tx, auth), auth)?; let dialect = PostgreSqlDialect {}; let ast = Parser::parse_sql(&dialect, sql_text).map_err(|error| DBError::SqlParser { sql: sql_text.to_string(), diff --git a/crates/core/src/sql/compiler.rs b/crates/core/src/sql/compiler.rs index eb17de1e8a4..97ba4a507e6 100644 --- a/crates/core/src/sql/compiler.rs +++ b/crates/core/src/sql/compiler.rs @@ -23,7 +23,7 @@ const MAX_SQL_LENGTH: usize = 50_000; pub fn compile_sql( db: &RelationalDB, auth: &AuthCtx, - tx: &T, + tx: &mut T, sql_text: &str, ) -> Result, DBError> { if sql_text.len() > MAX_SQL_LENGTH { @@ -266,7 +266,7 @@ mod tests { fn compile_sql( db: &RelationalDB, - tx: &T, + tx: &mut T, sql: &str, ) -> Result, DBError> { super::compile_sql(db, &AuthCtx::for_testing(), tx, sql) @@ -281,10 +281,10 @@ mod tests { let indexes = &[]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Compile query let sql = "select * from test where a = 1"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -303,10 +303,10 @@ mod tests { &[1.into(), 0.into()], )?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should work with any qualified field. let sql = "select * from test where a = 1 and b <> 3"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(2, query.len()); @@ -324,10 +324,10 @@ mod tests { let indexes = &[0.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); //Compile query let sql = "select * from test where a = 1"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -377,11 +377,11 @@ mod tests { let rows = run_for_testing(&db, sql)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); let CrudExpr::Query(QueryExpr { source: _, query: mut ops, - }) = compile_sql(&db, &tx, sql)?.remove(0) + }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; @@ -407,11 +407,11 @@ mod tests { let indexes = &[1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Note, order does not matter. // The sargable predicate occurs last, but we can still generate an index scan. let sql = "select * from test where a = 1 and b = 2"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(2, query.len()); @@ -429,11 +429,11 @@ mod tests { let indexes = &[1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Note, order does not matter. // The sargable predicate occurs first and we can generate an index scan. let sql = "select * from test where b = 2 and a = 1"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(2, query.len()); @@ -455,9 +455,9 @@ mod tests { ]; db.create_table_for_test_multi_column("test", schema, col_list![0, 1])?; - let tx = begin_mut_tx(&db); + let mut tx = begin_mut_tx(&db); let sql = "select * from test where b = 2 and a = 1"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -474,10 +474,10 @@ mod tests { let indexes = &[0.into(), 1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Compile query let sql = "select * from test where a = 1 or b = 2"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -495,10 +495,10 @@ mod tests { let indexes = &[1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Compile query let sql = "select * from test where b > 2"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -516,10 +516,10 @@ mod tests { let indexes = &[1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Compile query let sql = "select * from test where b > 2 and b < 5"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(1, query.len()); @@ -542,11 +542,11 @@ mod tests { let indexes = &[0.into(), 1.into()]; db.create_table_for_test("test", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Note, order matters - the equality condition occurs first which // means an index scan will be generated rather than the range condition. let sql = "select * from test where a = 3 and b > 2 and b < 5"; - let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else { + let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &mut tx, sql)?.remove(0) else { panic!("Expected QueryExpr"); }; assert_eq!(2, query.len()); @@ -569,10 +569,10 @@ mod tests { let indexes = &[]; let rhs_id = db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should push sargable equality condition below join let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: source_lhs, @@ -621,10 +621,10 @@ mod tests { let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)]; let rhs_id = db.create_table_for_test("rhs", schema, &[])?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should push equality condition below join let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: source_lhs, @@ -678,10 +678,10 @@ mod tests { let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)]; let rhs_id = db.create_table_for_test("rhs", schema, &[])?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should push equality condition below join let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c = 3"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: source_lhs, @@ -736,11 +736,11 @@ mod tests { let indexes = &[1.into()]; let rhs_id = db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should push the sargable equality condition into the join's left arg. // Should push the sargable range condition into the join's right arg. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3 and rhs.c < 4"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: source_lhs, @@ -807,11 +807,11 @@ mod tests { let indexes = &[0.into(), 1.into()]; let rhs_id = db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should generate an index join since there is an index on `lhs.b`. // Should push the sargable range condition into the index join's probe side. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: SourceExpr::DbTable(DbTable { table_id, .. }), @@ -889,11 +889,11 @@ mod tests { let indexes = col_list![0, 1]; let rhs_id = db.create_table_for_test_multi_column("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should generate an index join since there is an index on `lhs.b`. // Should push the sargable range condition into the index join's probe side. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c = 2 and rhs.b = 4 and rhs.d = 3"; - let exp = compile_sql(&db, &tx, sql)?.remove(0); + let exp = compile_sql(&db, &mut tx, sql)?.remove(0); let CrudExpr::Query(QueryExpr { source: SourceExpr::DbTable(DbTable { table_id, .. }), @@ -953,7 +953,7 @@ mod tests { let db = TestDB::durable()?; db.create_table_for_test("A", &[("x", AlgebraicType::U64)], &[])?; db.create_table_for_test("B", &[("y", AlgebraicType::U64)], &[])?; - assert!(compile_sql(&db, &begin_tx(&db), "select B.* from B join A on B.y = A.x").is_ok()); + assert!(compile_sql(&db, &mut begin_tx(&db), "select B.* from B join A on B.y = A.x").is_ok()); Ok(()) } @@ -970,27 +970,27 @@ mod tests { // TODO: Type check other operations deferred for the new query engine. assert!( - compile_sql(&db, &begin_tx(&db), sql).is_err(), + compile_sql(&db, &mut begin_tx(&db), sql).is_err(), // Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64` != `String(\"161853\"): String`, executing: `SELECT * FROM PlayerState WHERE entity_id = '161853'`".into()) ); // Check we can still compile the query if we remove the type mismatch and have multiple logical operations. let sql = "SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id = 2 AND entity_id = 3 OR entity_id = 4 OR entity_id = 5"; - assert!(compile_sql(&db, &begin_tx(&db), sql).is_ok()); + assert!(compile_sql(&db, &mut begin_tx(&db), sql).is_ok()); // Now verify when we have a type mismatch in the middle of the logical operations. let sql = "SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id"; assert!( - compile_sql(&db, &begin_tx(&db), sql).is_err(), + compile_sql(&db, &mut begin_tx(&db), sql).is_err(), // Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64 == U64(1): U64` and `PlayerState.entity_id: U64`, both sides must be an `Bool` expression, executing: `SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id`".into()) ); // Verify that all operands of `AND` must be `Bool`. let sql = "SELECT * FROM PlayerState WHERE entity_id AND entity_id"; assert!( - compile_sql(&db, &begin_tx(&db), sql).is_err(), + compile_sql(&db, &mut begin_tx(&db), sql).is_err(), // Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64` and `PlayerState.entity_id: U64`, both sides must be an `Bool` expression, executing: `SELECT * FROM PlayerState WHERE entity_id AND entity_id`".into()) ); Ok(()) diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index 2e2e2337509..6f30872e851 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -20,9 +20,8 @@ use anyhow::anyhow; use spacetimedb_datastore::execution_context::Workload; use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; use spacetimedb_datastore::traits::IsolationLevel; -use spacetimedb_expr::check::SchemaView; +use spacetimedb_expr::check::{SchemaView, TypingResult}; use spacetimedb_expr::errors::TypingError; -use spacetimedb_expr::expr::CallParams; use spacetimedb_expr::statement::Statement; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::metrics::ExecutionMetrics; @@ -191,15 +190,27 @@ pub struct SqlResult { pub metrics: ExecutionMetrics, } -struct DbParams<'a> { +struct SchemaViewerMut<'a> { db: &'a RelationalDB, - tx: &'a mut MutTx, + schema: SchemaViewer<'a, MutTx>, } -impl CallParams for DbParams<'_> { - fn create_or_get_param(&mut self, param: &ProductValue) -> Result { +impl SchemaView for SchemaViewerMut<'_> { + fn table_id(&self, name: &str) -> Option { + self.schema.table_id(name) + } + + fn schema_for_table(&self, table_id: TableId) -> Option> { + self.schema.schema_for_table(table_id) + } + + fn rls_rules_for_table(&self, table_id: TableId) -> anyhow::Result>> { + self.schema.rls_rules_for_table(table_id) + } + + fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult { self.db - .create_or_get_params(self.tx, ¶m) + .create_or_get_params(&mut self.schema, ¶ms) .map_err(|err| TypingError::Other(err.into())) } } @@ -215,20 +226,16 @@ pub async fn run( ) -> Result { // We parse the sql statement in a mutable transaction. // If it turns out to be a query, we downgrade the tx. - let (tx, stmt) = - db.with_auto_rollback( - db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), - |tx| match compile_sql_stmt(sql_text, &SchemaViewer::new(tx, &auth), &auth) { - Ok(Statement::Select(mut stmt)) => { - stmt.for_each_fun_call(&mut |param| { - db.create_or_get_params(tx, ¶m) - .map_err(|err| TypingError::Other(err.into())) - })?; - Ok(Statement::Select(stmt)) - } - result => result, + let (tx, stmt) = db.with_auto_rollback(db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), |tx| { + compile_sql_stmt( + sql_text, + &mut SchemaViewerMut { + db, + schema: SchemaViewer::new(tx, &auth), }, - )?; + &auth, + ) + })?; let mut metrics = ExecutionMetrics::default(); @@ -1619,6 +1626,10 @@ pub(crate) mod tests { true, )?; let arg_id = ST_RESERVED_SEQUENCE_RANGE as u64; + assert_eq!( + run_for_testing(&db, "select view_id, param_pos, param_name FROM st_view_param")?, + vec![product![arg_id as u32, 0u16, "x"]] + ); with_auto_commit(&db, |tx| -> Result<_, DBError> { tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id + 1, 0u8, 1i64])?; @@ -1636,6 +1647,12 @@ pub(crate) mod tests { vec![product![0u8, 1i64]] ); + // We have created the internal rows for view args + assert_eq!( + run_for_testing(&db, "select id FROM st_view_arg")?, + vec![product![arg_id], product![arg_id + 1]] + ); + Ok(()) } } diff --git a/crates/core/src/sql/parser.rs b/crates/core/src/sql/parser.rs index 66d216a5590..3dde1a49266 100644 --- a/crates/core/src/sql/parser.rs +++ b/crates/core/src/sql/parser.rs @@ -19,7 +19,7 @@ impl RowLevelExpr { auth_ctx: &AuthCtx, rls: &RawRowLevelSecurityDefV9, ) -> anyhow::Result { - let (sql, _) = parse_and_type_sub(&rls.sql, &SchemaViewer::new(tx, auth_ctx), auth_ctx)?; + let (sql, _) = parse_and_type_sub(&rls.sql, &mut SchemaViewer::new(tx, auth_ctx), auth_ctx)?; let table_id = sql.return_table_id().unwrap(); let schema = tx.schema_for_table(table_id)?; diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 0396a12f10f..842996c8305 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -472,7 +472,7 @@ impl ModuleSubscriptions { let hash = QueryHash::from_string(&sql, auth.caller(), false); let hash_with_param = QueryHash::from_string(&sql, auth.caller(), true); - let (mut_tx, _) = self.begin_mut_tx(Workload::Subscribe); + let (mut mut_tx, _) = self.begin_mut_tx(Workload::Subscribe); let existing_query = { let guard = self.subscriptions.read(); @@ -482,7 +482,7 @@ impl ModuleSubscriptions { let query = return_on_err_with_sql!( existing_query.map(Ok).unwrap_or_else(|| compile_query_with_hashes( &auth, - &*mut_tx, + &mut *mut_tx, &sql, hash, hash_with_param @@ -736,7 +736,7 @@ impl ModuleSubscriptions { } // We always get the db lock before the subscription lock to avoid deadlocks. - let (mut_tx, _tx_offset) = self.begin_mut_tx(Workload::Subscribe); + let (mut mut_tx, _tx_offset) = self.begin_mut_tx(Workload::Subscribe); let compile_timer = metrics.compilation_time.start_timer(); @@ -752,7 +752,7 @@ impl ModuleSubscriptions { super::subscription::get_all( |relational_db, tx| relational_db.get_all_tables_mut(tx).map(|schemas| schemas.into_iter()), &self.relational_db, - &*mut_tx, + &mut *mut_tx, &auth, )? .into_iter() @@ -769,7 +769,7 @@ impl ModuleSubscriptions { plans.push(unit); } else { plans.push(Arc::new( - compile_query_with_hashes(&auth, &*mut_tx, sql, hash, hash_with_param).map_err(|err| { + compile_query_with_hashes(&auth, &mut *mut_tx, sql, hash, hash_with_param).map_err(|err| { DBError::WithSql { error: Box::new(DBError::Other(err.into())), sql: sql.into(), @@ -1807,8 +1807,8 @@ mod tests { let auth = AuthCtx::for_testing(); let sql = "select * from t where id = 1"; - let tx = begin_tx(&db); - let plan = compile_read_only_query(&auth, &tx, sql)?; + let mut tx = begin_tx(&db); + let plan = compile_read_only_query(&auth, &mut tx, sql)?; let plan = Arc::new(plan); let (_, metrics) = subs.evaluate_queries(sender, &[plan], &tx, &auth, TableUpdateType::Subscribe)?; diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index 6b93e996ce4..dccb0b1dfc1 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -1681,8 +1681,8 @@ mod tests { fn compile_plan(db: &RelationalDB, sql: &str) -> ResultTest> { with_read_only(db, |tx| { let auth = AuthCtx::for_testing(); - let tx = SchemaViewer::new(&*tx, &auth); - let (plans, has_param) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap(); + let mut tx = SchemaViewer::new(tx, &auth); + let (plans, has_param) = SubscriptionPlan::compile(sql, &mut tx, &auth).unwrap(); let hash = QueryHash::from_string(sql, auth.caller(), has_param); Ok(Arc::new(Plan::new(plans, hash, sql.into()))) }) diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index 968a092e5d2..31b81b3dfc3 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -42,7 +42,7 @@ pub fn is_subscribe_to_all_tables(sql: &str) -> bool { pub fn compile_read_only_queryset( relational_db: &RelationalDB, auth: &AuthCtx, - tx: &Tx, + tx: &mut Tx, input: &str, ) -> Result, DBError> { let input = input.trim(); @@ -82,13 +82,13 @@ pub fn compile_read_only_queryset( /// Compile a string into a single read-only query. /// This returns an error if the string has multiple queries or mutations. -pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result { +pub fn compile_read_only_query(auth: &AuthCtx, tx: &mut Tx, input: &str) -> Result { if is_whitespace_or_empty(input) { return Err(SubscriptionError::Empty.into()); } - let tx = SchemaViewer::new(tx, auth); - let (plans, has_param) = SubscriptionPlan::compile(input, &tx, auth)?; + let mut tx = SchemaViewer::new(tx, auth); + let (plans, has_param) = SubscriptionPlan::compile(input, &mut tx, auth)?; let hash = QueryHash::from_string(input, auth.caller(), has_param); Ok(Plan::new(plans, hash, input.to_owned())) } @@ -97,7 +97,7 @@ pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result

( auth: &AuthCtx, - tx: &Tx, + tx: &mut Tx, input: &str, hash: QueryHash, hash_with_param: QueryHash, @@ -106,8 +106,8 @@ pub fn compile_query_with_hashes( return Err(SubscriptionError::Empty.into()); } - let tx = SchemaViewer::new(tx, auth); - let (plans, has_param) = SubscriptionPlan::compile(input, &tx, auth)?; + let mut tx = SchemaViewer::new(tx, auth); + let (plans, has_param) = SubscriptionPlan::compile(input, &mut tx, auth)?; if auth.bypass_rls() || has_param { // Note that when generating hashes for queries from owners, @@ -151,7 +151,7 @@ mod tests { use crate::db::relational_db::tests_utils::{ begin_mut_tx, begin_tx, insert, with_auto_commit, with_read_only, TestDB, }; - use crate::db::relational_db::MutTx; + use crate::db::relational_db::{tests_utils, MutTx}; use crate::host::module_host::{DatabaseTableUpdate, DatabaseUpdate, UpdatesRelValue}; use crate::sql::execute::collect_result; use crate::sql::execute::tests::run_for_testing; @@ -164,6 +164,7 @@ mod tests { use itertools::Itertools; use spacetimedb_client_api_messages::websocket::{BsatnFormat, CompressableQueryUpdate, Compression}; use spacetimedb_datastore::execution_context::Workload; + use spacetimedb_datastore::system_tables::ST_RESERVED_SEQUENCE_RANGE; use spacetimedb_lib::bsatn; use spacetimedb_lib::db::auth::{StAccess, StTableType}; use spacetimedb_lib::error::ResultTest; @@ -421,9 +422,9 @@ mod tests { db.create_table_for_test("a", schema, indexes)?; db.create_table_for_test("b", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); let sql = "SELECT b.* FROM b JOIN a ON b.n = a.n WHERE b.data > 200"; - let result = compile_read_only_query(&AuthCtx::for_testing(), &tx, sql); + let result = compile_read_only_query(&AuthCtx::for_testing(), &mut tx, sql); assert!(result.is_ok()); Ok(()) } @@ -454,10 +455,10 @@ mod tests { }; db.commit_tx(tx)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); let sql = "select * from test where b = 3"; - let mut exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?; + let mut exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?; let Some(CrudExpr::Query(query)) = exp.pop() else { panic!("unexpected query {:#?}", exp[0]); @@ -609,8 +610,8 @@ mod tests { AND MobileEntityState.location_z > 96000 \ AND MobileEntityState.location_z < 192000"; - let tx = begin_tx(&db); - let qset = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, sql_query)?; + let mut tx = begin_tx(&db); + let qset = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, sql_query)?; for q in qset { let result = run_query( @@ -684,7 +685,7 @@ mod tests { let indexes = &[ColId(0), ColId(1)]; db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // All single table queries are supported let scans = [ @@ -696,7 +697,7 @@ mod tests { "SELECT * FROM lhs WHERE id > 5", ]; for scan in scans { - let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, scan)? + let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, scan)? .pop() .unwrap(); assert_eq!(expr.kind(), Supported::Select, "{scan}\n{expr:#?}"); @@ -705,7 +706,7 @@ mod tests { // Only index semijoins are supported let joins = ["SELECT lhs.* FROM lhs JOIN rhs ON lhs.id = rhs.id WHERE rhs.y < 10"]; for join in joins { - let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, join)? + let expr = compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, join)? .pop() .unwrap(); assert_eq!(expr.kind(), Supported::Semijoin, "{join}\n{expr:#?}"); @@ -718,7 +719,7 @@ mod tests { "SELECT * FROM lhs JOIN rhs ON lhs.id = rhs.id WHERE lhs.x < 10", ]; for join in joins { - match compile_read_only_queryset(&db, &AuthCtx::for_testing(), &tx, join) { + match compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, join) { Err(DBError::Subscription(SubscriptionError::Unsupported(_)) | DBError::TypeError(_)) => (), x => panic!("Unexpected: {x:?}"), } @@ -756,10 +757,10 @@ mod tests { fn compile_query(db: &RelationalDB) -> ResultTest { with_read_only(db, |tx| { let auth = AuthCtx::for_testing(); - let tx = SchemaViewer::new(tx, &auth); + let mut tx = SchemaViewer::new(tx, &auth); // Should be answered using an index semijion let sql = "select lhs.* from lhs join rhs on lhs.id = rhs.id where rhs.y >= 2 and rhs.y <= 4"; - Ok(SubscriptionPlan::compile(sql, &tx, &auth) + Ok(SubscriptionPlan::compile(sql, &mut tx, &auth) .map(|(mut plans, _)| { assert_eq!(plans.len(), 1); plans.pop().unwrap() @@ -781,10 +782,10 @@ mod tests { fn compile_query(db: &RelationalDB) -> ResultTest { with_read_only(db, |tx| { let auth = AuthCtx::for_testing(); - let tx = SchemaViewer::new(tx, &auth); + let mut tx = SchemaViewer::new(tx, &auth); // Should be answered using an index semijion let sql = "select lhs.* from lhs join rhs on lhs.id = rhs.id where lhs.x >= 5 and lhs.x <= 7"; - Ok(SubscriptionPlan::compile(sql, &tx, &auth) + Ok(SubscriptionPlan::compile(sql, &mut tx, &auth) .map(|(mut plans, _)| { assert_eq!(plans.len(), 1); plans.pop().unwrap() @@ -1447,4 +1448,36 @@ mod tests { assert_eq!(metrics.index_seeks, 8); Ok(()) } + + // Verify calling views with params + // TODO: All testing use the old query compiler, so we can't test this yet. + #[test] + fn test_view_params() -> ResultTest<()> { + let db = TestDB::durable()?; + let schema = [("a", AlgebraicType::U8), ("b", AlgebraicType::I64)]; + let (_view_id, table_id) = tests_utils::create_view_for_test( + &db, + "my_view", + &schema, + ProductType::from([("x", AlgebraicType::U8)]), + true, + )?; + let arg_id = ST_RESERVED_SEQUENCE_RANGE as u64; + + with_auto_commit(&db, |tx| -> Result<_, DBError> { + tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id + 1, 0u8, 1i64])?; + tests_utils::insert_into_view(&db, tx, table_id, None, product![arg_id, 1u8, 2i64])?; + Ok(()) + })?; + + let mut tx = begin_tx(&db); + + let err = + compile_read_only_queryset(&db, &AuthCtx::for_testing(), &mut tx, "SELECT * FROM my_view(1)").unwrap_err(); + assert_eq!( + err.to_string(), + "InternalError: Read-only queries cannot create parameters".to_string() + ); + Ok(()) + } } diff --git a/crates/core/src/subscription/subscription.rs b/crates/core/src/subscription/subscription.rs index 1fb13e8ee56..5b14700c2ff 100644 --- a/crates/core/src/subscription/subscription.rs +++ b/crates/core/src/subscription/subscription.rs @@ -615,7 +615,7 @@ impl AuthAccess for ExecutionSet { pub(crate) fn get_all( get_all_tables: F, relational_db: &RelationalDB, - tx: &T, + tx: &mut T, auth: &AuthCtx, ) -> Result, DBError> where @@ -627,8 +627,8 @@ where .filter(|t| t.table_type == StTableType::User && auth.has_read_access(t.table_access)) .map(|schema| { let sql = format!("SELECT * FROM {}", schema.table_name); - let tx = SchemaViewer::new(tx, auth); - SubscriptionPlan::compile(&sql, &tx, auth).map(|(plans, has_param)| { + let mut tx = SchemaViewer::new(tx, auth); + SubscriptionPlan::compile(&sql, &mut tx, auth).map(|(plans, has_param)| { Plan::new( plans, QueryHash::from_string( @@ -701,11 +701,11 @@ mod tests { let indexes = &[0.into(), 1.into()]; let rhs_id = db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should generate an index join since there is an index on `lhs.b`. // Should push the sargable range condition into the index join's probe side. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3"; - let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?.remove(0); + let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?.remove(0); let CrudExpr::Query(mut expr) = exp else { panic!("unexpected result from compilation: {exp:#?}"); @@ -781,11 +781,11 @@ mod tests { let indexes = &[0.into(), 1.into()]; let _ = db.create_table_for_test("rhs", schema, indexes)?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should generate an index join since there is an index on `lhs.b`. // Should push the sargable range condition into the index join's probe side. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3"; - let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?.remove(0); + let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?.remove(0); let CrudExpr::Query(mut expr) = exp else { panic!("unexpected result from compilation: {exp:#?}"); @@ -865,12 +865,12 @@ mod tests { .create_table_for_test("rhs", schema, indexes) .expect("Failed to create_table_for_test rhs"); - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); // Should generate an index join since there is an index on `lhs.b`. // Should push the sargable range condition into the index join's probe side. let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3"; - let exp = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql) + let exp = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql) .expect("Failed to compile_sql") .remove(0); diff --git a/crates/core/src/util/slow.rs b/crates/core/src/util/slow.rs index ea6701658cc..17522e62b7a 100644 --- a/crates/core/src/util/slow.rs +++ b/crates/core/src/util/slow.rs @@ -65,14 +65,14 @@ mod tests { use spacetimedb_vm::relation::MemTable; fn run_query(db: &Arc, sql: String) -> ResultTest { - let tx = begin_tx(db); - let q = compile_sql(db, &AuthCtx::for_testing(), &tx, &sql)?; + let mut tx = begin_tx(db); + let q = compile_sql(db, &AuthCtx::for_testing(), &mut tx, &sql)?; Ok(execute_for_testing(db, &sql, q)?.pop().unwrap()) } fn run_query_write(db: &Arc, sql: String) -> ResultTest<()> { - let tx = begin_tx(db); - let q = compile_sql(db, &AuthCtx::for_testing(), &tx, &sql)?; + let mut tx = begin_tx(db); + let q = compile_sql(db, &AuthCtx::for_testing(), &mut tx, &sql)?; drop(tx); execute_for_testing(db, &sql, q)?; @@ -92,10 +92,10 @@ mod tests { } Ok(()) })?; - let tx = begin_tx(&db); + let mut tx = begin_tx(&db); let sql = "select * from test where x > 0"; - let q = compile_sql(&db, &AuthCtx::for_testing(), &tx, sql)?; + let q = compile_sql(&db, &AuthCtx::for_testing(), &mut tx, sql)?; let slow = SlowQueryLogger::new(sql, Some(Duration::from_millis(1)), tx.ctx.workload()); diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs index 14e4c59070b..c6190e2b106 100644 --- a/crates/expr/src/check.rs +++ b/crates/expr/src/check.rs @@ -9,12 +9,12 @@ use super::{ type_expr, type_proj, type_select, }; use crate::errors::{TableFunc, UnexpectedFunctionType}; -use crate::expr::{Expr, LeftDeepJoin, ProjectList, ProjectName, Relvar}; +use crate::expr::{Expr, FieldProject, LeftDeepJoin, ProjectList, ProjectName, Relvar}; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::AlgebraicType; use spacetimedb_primitives::{ArgId, TableId}; use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type; -use spacetimedb_sats::ProductValue; +use spacetimedb_sats::{AlgebraicValue, ProductValue}; use spacetimedb_schema::schema::TableOrViewSchema; use spacetimedb_sql_parser::ast::{BinOp, SqlExpr, SqlLiteral}; use spacetimedb_sql_parser::{ @@ -35,9 +35,7 @@ pub trait SchemaView { self.table_id(name).and_then(|table_id| self.schema_for_table(table_id)) } - fn get_or_create_params(&self, params: &ProductValue) -> TypingResult { - Ok(ArgId::SENTINEL) - } + fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult; } #[derive(Default)] @@ -60,9 +58,9 @@ pub trait TypeChecker { type Ast; type Set; - fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult; + fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult; - fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult; + fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult; fn type_view_params( schema: &TableOrViewSchema, @@ -152,25 +150,35 @@ pub trait TypeChecker { } fn type_params( + tx: &mut impl SchemaView, from: RelExpr, schema: Arc, alias: Box, params: Option, - ) -> RelExpr { + ) -> TypingResult { match params { - None => from, - Some(args) => RelExpr::FunCall( - Relvar { - schema, - alias, - delta: None, - }, - args, - ), + None => Ok(from), + Some(args) => { + let new_arg_id = tx.get_or_create_params(args)?; + let arg_id_col = schema.inner().get_column_by_name("arg_id").unwrap().col_pos; + + Ok(RelExpr::Select( + Box::new(from), + Expr::BinOp( + BinOp::Eq, + Box::new(Expr::Field(FieldProject { + table: alias, + field: arg_id_col.idx(), + ty: AlgebraicType::U64, + })), + Box::new(Expr::Value(AlgebraicValue::U64(new_arg_id.0), AlgebraicType::U64)), + ), + )) + } } } - fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult { + fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult { match from { SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => { let schema = Self::type_relvar(tx, &name)?; @@ -202,7 +210,7 @@ pub trait TypeChecker { } let schema = Self::type_relvar(tx, &name)?; let arg = Self::type_view_params(&schema, vars, params)?; - let lhs = Box::new(Self::type_params(join, schema.clone(), alias.clone(), arg)); + let lhs = Box::new(Self::type_params(tx, join, schema.clone(), alias.clone(), arg)?); let rhs = Relvar { schema, @@ -237,7 +245,7 @@ pub trait TypeChecker { delta: None, }); - Ok(Self::type_params(from, schema, alias, arg)) + Self::type_params(tx, from, schema, alias, arg) } } } @@ -256,11 +264,11 @@ impl TypeChecker for SubChecker { type Ast = SqlSelect; type Set = SqlSelect; - fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult { + fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult { Self::type_set(ast, &mut Relvars::default(), tx) } - fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult { + fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult { match ast { SqlSelect { project, @@ -283,7 +291,7 @@ impl TypeChecker for SubChecker { } /// Parse and type check a subscription query -pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> { +pub fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> { let ast = parse_subscription(sql)?; let has_param = ast.has_parameter(); let ast = ast.resolve_sender(auth.caller()); @@ -303,15 +311,16 @@ fn expect_table_type(expr: ProjectList) -> TypingResult { pub mod test_utils { use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType}; - use spacetimedb_primitives::TableId; - use spacetimedb_sats::AlgebraicType; + use spacetimedb_primitives::{ArgId, TableId}; + use spacetimedb_sats::{AlgebraicType, ProductValue}; use spacetimedb_schema::{ def::ModuleDef, schema::{Schema, TableOrViewSchema, TableSchema}, }; + use std::collections::HashMap; use std::sync::Arc; - use super::SchemaView; + use super::{SchemaView, TypingResult}; pub struct ViewInfo<'a> { pub(crate) name: &'a str, pub(crate) columns: &'a [(&'a str, AlgebraicType)], @@ -333,7 +342,38 @@ pub mod test_utils { builder.finish().try_into().expect("failed to generate module def") } - pub struct SchemaViewer(pub ModuleDef); + pub struct MockCallParams { + counter: u64, + params: HashMap, + } + + impl Default for MockCallParams { + fn default() -> Self { + Self::new() + } + } + + impl MockCallParams { + pub fn new() -> Self { + Self { + counter: 0, + params: HashMap::new(), + } + } + + pub fn get_or_insert(&mut self, value: ProductValue) -> ArgId { + if let Some(existing) = self.params.get(&value) { + *existing + } else { + self.counter += 1; + let arg_id = ArgId(self.counter - 1); + self.params.insert(value, arg_id); + arg_id + } + } + } + + pub struct SchemaViewer(pub ModuleDef, pub MockCallParams); impl SchemaView for SchemaViewer { fn table_id(&self, name: &str) -> Option { @@ -370,6 +410,10 @@ pub mod test_utils { fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result>> { Ok(vec![]) } + + fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult { + Ok(self.1.get_or_insert(params)) + } } } @@ -423,13 +467,13 @@ mod tests { } /// A wrapper around [super::parse_and_type_sub] that takes a dummy [AuthCtx] - fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult { + fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView) -> TypingResult { super::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan) } #[test] fn valid_literals() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -498,27 +542,27 @@ mod tests { msg: "timestamp ms with timezone", }, ] { - let result = parse_and_type_sub(sql, &tx); + let result = parse_and_type_sub(sql, &mut tx); assert!(result.is_ok(), "name: {}, error: {}", msg, result.unwrap_err()); } } #[test] fn valid_literals_for_type() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); for ty in [ "i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64", "i128", "u128", "i256", "u256", ] { let sql = format!("select * from t where {ty} = 127"); - let result = parse_and_type_sub(&sql, &tx); + let result = parse_and_type_sub(&sql, &mut tx); assert!(result.is_ok(), "Failed to parse {ty}: {}", result.unwrap_err()); } } #[test] fn invalid_literals() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -547,14 +591,14 @@ mod tests { msg: "Float as integer", }, ] { - let result = parse_and_type_sub(sql, &tx); + let result = parse_and_type_sub(sql, &mut tx); assert!(result.is_err(), "{msg}"); } } #[test] fn valid() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -611,14 +655,14 @@ mod tests { msg: "Type inner join + projection", }, ] { - let result = parse_and_type_sub(sql, &tx); + let result = parse_and_type_sub(sql, &mut tx); assert!(result.is_ok(), "{msg}"); } } #[test] fn invalid() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -683,7 +727,7 @@ mod tests { msg: "Columns must be qualified in join expressions", }, ] { - let result = parse_and_type_sub(sql, &tx); + let result = parse_and_type_sub(sql, &mut tx); assert!(result.is_err(), "{msg}"); } } diff --git a/crates/expr/src/errors.rs b/crates/expr/src/errors.rs index 4a91980fde1..3129ab476f3 100644 --- a/crates/expr/src/errors.rs +++ b/crates/expr/src/errors.rs @@ -172,6 +172,8 @@ pub enum TypingError { FilterReturnType(#[from] FilterReturnType), #[error(transparent)] TableFunc(#[from] TableFunc), + #[error("InternalError: Read-only queries cannot create parameters")] + ParamsReadOnly, #[error(transparent)] Other(#[from] anyhow::Error), } diff --git a/crates/expr/src/expr.rs b/crates/expr/src/expr.rs index f6556f34561..5ae8484c675 100644 --- a/crates/expr/src/expr.rs +++ b/crates/expr/src/expr.rs @@ -238,35 +238,6 @@ impl ProjectList { Self::Agg(_, _, name, ty) => f(name, ty), } } - - /// Iterate over the function calls in this projection list - pub fn for_each_fun_call( - &mut self, - f: &mut impl FnMut(ProductValue) -> Result, - ) -> Result<(), TypingError> { - match self { - ProjectList::Name(input) => { - for proj in input { - match proj { - ProjectName::None(expr) | ProjectName::Some(expr, _) => { - expr.for_each_fun_call(f)?; - } - } - } - } - ProjectList::List(input, _) => { - for expr in input { - expr.for_each_fun_call(f)?; - } - } - ProjectList::Limit(input, _) => { - input.for_each_fun_call(f)?; - } - ProjectList::Agg(_, _, _, _) => {} - } - - Ok(()) - } } /// A logical relational expression @@ -398,31 +369,6 @@ impl RelExpr { _ => None, } } - - fn for_each_fun_call( - &mut self, - f: &mut impl FnMut(ProductValue) -> Result, - ) -> Result<(), TypingError> { - // For function calls, we need to filter by the argument id - if let RelExpr::FunCall(relvar, param) = self { - let new_arg_id = f(param.clone())?; - let arg_id_col = relvar.schema.inner().get_column_by_name("arg_id").unwrap().col_pos; - - *self = RelExpr::Select( - Box::new(RelExpr::RelVar(relvar.clone())), - Expr::BinOp( - BinOp::Eq, - Box::new(Expr::Field(FieldProject { - table: relvar.alias.clone(), - field: arg_id_col.idx(), - ty: AlgebraicType::U64, - })), - Box::new(Expr::Value(AlgebraicValue::U64(new_arg_id.0), AlgebraicType::U64)), - ), - ); - } - Ok(()) - } } /// A left deep binary cross product diff --git a/crates/expr/src/rls.rs b/crates/expr/src/rls.rs index 89cd9e5a2a7..5a1cc312dcb 100644 --- a/crates/expr/src/rls.rs +++ b/crates/expr/src/rls.rs @@ -12,7 +12,7 @@ use crate::{ /// The main driver of RLS resolution for subscription queries. /// Mainly a wrapper around [resolve_views_for_expr]. pub fn resolve_views_for_sub( - tx: &impl SchemaView, + tx: &mut impl SchemaView, expr: ProjectName, auth: &AuthCtx, has_param: &mut bool, @@ -54,7 +54,11 @@ pub fn resolve_views_for_sub( /// The main driver of RLS resolution for sql queries. /// Mainly a wrapper around [resolve_views_for_expr]. -pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &AuthCtx) -> anyhow::Result { +pub fn resolve_views_for_sql( + tx: &mut impl SchemaView, + expr: ProjectList, + auth: &AuthCtx, +) -> anyhow::Result { // RLS does not apply to the database owner if auth.bypass_rls() { return Ok(expr); @@ -62,44 +66,41 @@ pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &Aut // The subscription language is a subset of the sql language. // Use the subscription helper if this is a compliant expression. // Use the generic resolver otherwise. - let resolve_for_sub = |expr| resolve_views_for_sub(tx, expr, auth, &mut false); - let resolve_for_sql = |expr| { - resolve_views_for_expr( - // Use all default values - tx, - expr, - None, - Rc::new(ResolveList::None), - &mut false, - &mut 0, - auth, - ) - }; match expr { - ProjectList::Limit(expr, n) => Ok(ProjectList::Limit(Box::new(resolve_views_for_sql(tx, *expr, auth)?), n)), + ProjectList::Limit(expr, n) => { + let expr = resolve_views_for_sql(tx, *expr, auth)?; + Ok(ProjectList::Limit(Box::new(expr), n)) + } + ProjectList::Name(exprs) => Ok(ProjectList::Name( exprs .into_iter() - .map(resolve_for_sub) + .map(|expr| resolve_views_for_sub(tx, expr, auth, &mut false)) .collect::, _>>()? .into_iter() .flatten() .collect(), )), + ProjectList::List(exprs, fields) => Ok(ProjectList::List( exprs .into_iter() - .map(resolve_for_sql) + .map(|expr| { + resolve_views_for_expr(tx, expr, None, Rc::new(ResolveList::None), &mut false, &mut 0, auth) + }) .collect::, _>>()? .into_iter() .flatten() .collect(), fields, )), + ProjectList::Agg(exprs, AggType::Count, name, ty) => Ok(ProjectList::Agg( exprs .into_iter() - .map(resolve_for_sql) + .map(|expr| { + resolve_views_for_expr(tx, expr, None, Rc::new(ResolveList::None), &mut false, &mut 0, auth) + }) .collect::, _>>()? .into_iter() .flatten() @@ -203,7 +204,7 @@ impl ResolveList { /// i.e. the subtree rooted at `a` in the above example, /// must be pushed below the leftmost leaf node of the view expansion. fn resolve_views_for_expr( - tx: &impl SchemaView, + tx: &mut impl SchemaView, view: RelExpr, return_table_id: Option, resolving: Rc, @@ -473,21 +474,23 @@ mod tests { use pretty_assertions as pretty; use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, AlgebraicValue, Identity, ProductType}; - use spacetimedb_primitives::TableId; + use spacetimedb_primitives::{ArgId, TableId}; + use spacetimedb_sats::ProductValue; use spacetimedb_schema::{ def::ModuleDef, schema::{Schema, TableOrViewSchema, TableSchema}, }; use spacetimedb_sql_parser::ast::BinOp; + use super::resolve_views_for_sub; + use crate::check::test_utils::MockCallParams; + use crate::check::TypingResult; use crate::{ check::{parse_and_type_sub, test_utils::build_module_def, SchemaView}, expr::{Expr, FieldProject, LeftDeepJoin, ProjectName, RelExpr, Relvar}, }; - use super::resolve_views_for_sub; - - pub struct SchemaViewer(pub ModuleDef); + pub struct SchemaViewer(pub ModuleDef, pub MockCallParams); impl SchemaView for SchemaViewer { fn table_id(&self, name: &str) -> Option { @@ -526,6 +529,10 @@ mod tests { _ => Ok(vec![]), } } + + fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult { + Ok(self.1.get_or_insert(params)) + } } fn module_def() -> ModuleDef { @@ -549,17 +556,17 @@ mod tests { } /// Parse, type check, and resolve RLS rules - fn resolve(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> anyhow::Result> { + fn resolve(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> anyhow::Result> { let (expr, _) = parse_and_type_sub(sql, tx, auth)?; resolve_views_for_sub(tx, expr, auth, &mut false) } #[test] fn test_rls_for_owner() -> anyhow::Result<()> { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); let auth = AuthCtx::new(Identity::ONE, Identity::ONE); let sql = "select * from users"; - let resolved = resolve(sql, &tx, &auth)?; + let resolved = resolve(sql, &mut tx, &auth)?; let users_schema = tx.schema("users").unwrap(); @@ -577,10 +584,10 @@ mod tests { #[test] fn test_rls_for_non_owner() -> anyhow::Result<()> { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); let auth = AuthCtx::new(Identity::ZERO, Identity::ONE); let sql = "select * from users"; - let resolved = resolve(sql, &tx, &auth)?; + let resolved = resolve(sql, &mut tx, &auth)?; let users_schema = tx.schema("users").unwrap(); @@ -612,10 +619,10 @@ mod tests { #[test] fn test_multiple_rls_rules_for_table() -> anyhow::Result<()> { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); let auth = AuthCtx::new(Identity::ZERO, Identity::ONE); let sql = "select * from player where level_num = 5"; - let resolved = resolve(sql, &tx, &auth)?; + let resolved = resolve(sql, &mut tx, &auth)?; let users_schema = tx.schema("users").unwrap(); let admins_schema = tx.schema("admins").unwrap(); diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index 1fc882cf5ca..317cd72a7a8 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -394,11 +394,11 @@ impl TypeChecker for SqlChecker { type Ast = SqlSelect; type Set = SqlSelect; - fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult { + fn type_ast(ast: Self::Ast, tx: &mut impl SchemaView) -> TypingResult { Self::type_set(ast, &mut Relvars::default(), tx) } - fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult { + fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &mut impl SchemaView) -> TypingResult { match ast { SqlSelect { project, @@ -439,7 +439,7 @@ impl TypeChecker for SqlChecker { } } -pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult { +pub fn parse_and_type_sql(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult { match parse_sql(sql)?.resolve_sender(auth.caller()) { SqlAst::Select(ast) => Ok(Statement::Select(SqlChecker::type_ast(ast, tx)?)), SqlAst::Insert(insert) => Ok(Statement::DML(DML::Insert(type_insert(insert, tx)?))), @@ -451,7 +451,7 @@ pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Ty } /// Parse and type check a *general* query into a [StatementCtx]. -pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult> { +pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &mut impl SchemaView, auth: &AuthCtx) -> TypingResult> { let statement = parse_and_type_sql(sql, tx, auth)?; Ok(StatementCtx { statement, @@ -520,13 +520,13 @@ mod tests { } /// A wrapper around [super::parse_and_type_sql] that takes a dummy [AuthCtx] - fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult { + fn parse_and_type_sql(sql: &str, tx: &mut impl SchemaView) -> TypingResult { super::parse_and_type_sql(sql, tx, &AuthCtx::for_testing()) } #[test] fn valid() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); for sql in [ "select str from t", @@ -534,14 +534,14 @@ mod tests { "select t.str, arr from t", "select * from t limit 5", ] { - let result = parse_and_type_sql(sql, &tx); + let result = parse_and_type_sql(sql, &mut tx); assert!(result.is_ok()); } } #[test] fn invalid() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); for sql in [ // Unqualified columns in a join @@ -551,7 +551,7 @@ mod tests { // Unqualified name in join expression "select t.* from t join s on t.u32 = s.u32 where bytes = 0xABCD", ] { - let result = parse_and_type_sql(sql, &tx); + let result = parse_and_type_sql(sql, &mut tx); assert!(result.is_err()); } } @@ -581,7 +581,7 @@ mod tests { #[test] fn views() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -606,7 +606,7 @@ mod tests { msg: "Function call returning view with parameters", }, ] { - let result = parse_and_type_sql(sql, &tx).inspect_err(|e| { + let result = parse_and_type_sql(sql, &mut tx).inspect_err(|e| { panic!("Expected OK for `{sql}` but got error: {e}"); }); assert!(result.is_ok(), "{msg}: {sql}"); @@ -630,14 +630,14 @@ mod tests { msg: "`v` does not take parameters", }, ] { - let result = parse_and_type_sql(sql, &tx); + let result = parse_and_type_sql(sql, &mut tx); assert!(result.is_err(), "{msg}"); } } #[test] fn params_validation() { - let tx = SchemaViewer(module_def()); + let mut tx = SchemaViewer(module_def(), Default::default()); struct TestCase { sql: &'static str, @@ -670,7 +670,7 @@ mod tests { msg: "Unexpected function type. Expected: (U32, String) != Inferred: (U32, String, Num?)", }, ] { - let result = parse_and_type_sql(sql, &tx); + let result = parse_and_type_sql(sql, &mut tx); if msg == "Correct parameters" { assert!(result.is_ok(), "{msg}: {sql}"); } else if let Err(err) = &result { diff --git a/crates/physical-plan/src/plan.rs b/crates/physical-plan/src/plan.rs index 597ab6dd6eb..20b17d1e642 100644 --- a/crates/physical-plan/src/plan.rs +++ b/crates/physical-plan/src/plan.rs @@ -1402,6 +1402,7 @@ mod tests { use std::sync::Arc; use pretty_assertions::assert_eq; + use spacetimedb_expr::check::test_utils::MockCallParams; use spacetimedb_expr::{ check::{SchemaView, TypingResult}, expr::ProjectName, @@ -1410,9 +1411,9 @@ mod tests { use spacetimedb_lib::{ db::auth::{StAccess, StTableType}, identity::AuthCtx, - AlgebraicType, AlgebraicValue, + AlgebraicType, AlgebraicValue, ProductValue, }; - use spacetimedb_primitives::{ColId, ColList, ColSet, TableId, ViewId}; + use spacetimedb_primitives::{ArgId, ColId, ColList, ColSet, TableId, ViewId}; use spacetimedb_schema::def::ViewParamDefSimple; use spacetimedb_schema::identifier::Identifier; use spacetimedb_schema::schema::ViewDefInfo; @@ -1431,6 +1432,7 @@ mod tests { struct SchemaViewer { schemas: Vec>, + params: MockCallParams, } impl SchemaView for SchemaViewer { @@ -1448,6 +1450,10 @@ mod tests { fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result>> { Ok(vec![]) } + + fn get_or_create_params(&mut self, params: ProductValue) -> TypingResult { + Ok(self.params.get_or_insert(params)) + } } fn schema_with_params( @@ -1534,7 +1540,7 @@ mod tests { } /// A wrapper around [spacetimedb_expr::check::parse_and_type_sub] that takes a dummy [AuthCtx] - fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult { + fn parse_and_type_sub(sql: &str, tx: &mut impl SchemaView) -> TypingResult { spacetimedb_expr::check::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan) } @@ -1552,14 +1558,15 @@ mod tests { Some(0), )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![t.clone()], + params: Default::default(), }; let sql = "select * from t"; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { @@ -1584,14 +1591,15 @@ mod tests { Some(0), )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![t.clone()], + params: Default::default(), }; let sql = "select * from t where x = 5"; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { @@ -1678,8 +1686,9 @@ mod tests { Some(0), )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![u.clone(), l.clone(), b.clone()], + params: Default::default(), }; let sql = " @@ -1691,7 +1700,7 @@ mod tests { where u.identity = 5 "; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Plan: @@ -1870,8 +1879,9 @@ mod tests { Some(0), )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![m.clone(), w.clone(), p.clone()], + params: Default::default(), }; let sql = " @@ -1884,7 +1894,7 @@ mod tests { where 5 = m.employee and 5 = v.employee "; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Plan: @@ -2052,13 +2062,14 @@ mod tests { None, )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![t.clone()], + params: Default::default(), }; let sql = "select * from t where x = 3 and y = 4 and z = 5"; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Select index on (x, y, z) @@ -2081,7 +2092,7 @@ mod tests { // Test permutations of the same query let sql = "select * from t where z = 5 and y = 4 and x = 3"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { @@ -2102,7 +2113,7 @@ mod tests { }; let sql = "select * from t where x = 3 and y = 4"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Select index on x @@ -2130,7 +2141,7 @@ mod tests { }; let sql = "select * from t where w = 5 and x = 4"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Select index on x @@ -2158,7 +2169,7 @@ mod tests { }; let sql = "select * from t where y = 1"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); // Do not select index on (y, z) @@ -2173,7 +2184,7 @@ mod tests { // Select index on [y, z] let sql = "select * from t where y = 1 and z = 2"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { @@ -2192,7 +2203,7 @@ mod tests { // Check permutations of the same query let sql = "select * from t where z = 2 and y = 1"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { @@ -2211,7 +2222,7 @@ mod tests { // Select index on (y, z) and filter on (w) let sql = "select * from t where w = 1 and y = 2 and z = 3"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); let plan = match pp { @@ -2251,12 +2262,13 @@ mod tests { None, )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![t.clone()], + params: Default::default(), }; - let compile = |sql| { - let stmt = parse_and_type_sql(sql, &db, &AuthCtx::for_testing()).unwrap(); + let mut compile = |sql| { + let stmt = parse_and_type_sql(sql, &mut db, &AuthCtx::for_testing()).unwrap(); let Statement::Select(select) = stmt else { unreachable!() }; @@ -2322,20 +2334,20 @@ mod tests { Some(&[("param_id", AlgebraicType::U64)]), )); - let db = SchemaViewer { + let mut db = SchemaViewer { schemas: vec![t.clone(), v.clone()], + params: Default::default(), }; let sql = "select * from v(0)"; let auth = AuthCtx::for_testing(); - let lp = parse_and_type_sub(sql, &db).unwrap(); - dbg!(&lp); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); - dbg!(&pp); match pp { ProjectPlan::None(PhysicalPlan::Filter(input, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => { - assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 1, .. }))); + // This is the internal parameter filter + assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 0, .. }))); assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U64(0)))); match *input { @@ -2349,12 +2361,12 @@ mod tests { }; let sql = "select * from v(0) as x JOIN t ON x.id = t.id"; - let lp = parse_and_type_sub(sql, &db).unwrap(); + let lp = parse_and_type_sub(sql, &mut db).unwrap(); let pp = compile_select(lp).optimize(&auth).unwrap(); match pp { ProjectPlan::None(PhysicalPlan::Filter(_, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => { - assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 1, .. }))); + assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 0, .. }))); assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U64(0)))); } proj => panic!("unexpected project: {proj:#?}"), diff --git a/crates/query/src/lib.rs b/crates/query/src/lib.rs index 96efff60ddb..7647efc7444 100644 --- a/crates/query/src/lib.rs +++ b/crates/query/src/lib.rs @@ -4,8 +4,6 @@ use spacetimedb_execution::{ pipelined::ProjectListExecutor, Datastore, DeltaStore, }; -use spacetimedb_expr::errors::TypingError; -use spacetimedb_expr::expr::CallParams; use spacetimedb_expr::{ check::{parse_and_type_sub, SchemaView}, expr::ProjectList, @@ -17,30 +15,15 @@ use spacetimedb_physical_plan::{ compile::{compile_dml_plan, compile_select, compile_select_list}, plan::{ProjectListPlan, ProjectPlan}, }; -use spacetimedb_primitives::{ArgId, TableId}; -use std::collections::HashMap; +use spacetimedb_primitives::TableId; /// DIRTY HACK ALERT: Maximum allowed length, in UTF-8 bytes, of SQL queries. /// Any query longer than this will be rejected. /// This prevents a stack overflow when compiling queries with deeply-nested `AND` and `OR` conditions. const MAX_SQL_LENGTH: usize = 50_000; - -pub trait CallParamsExt { - fn get_arg(&self, params: &ProductValue) -> Result; -} - -pub struct MockCallParams { - params: HashMap, -} -impl CallParamsExt for MockCallParams { - fn get_arg(&self, params: &ProductValue) -> Result { - todo!() - } -} - pub fn compile_subscription( sql: &str, - tx: &impl SchemaView, + tx: &mut impl SchemaView, auth: &AuthCtx, ) -> Result<(Vec, TableId, Box, bool)> { if sql.len() > MAX_SQL_LENGTH { @@ -72,7 +55,7 @@ pub fn compile_subscription( } /// A utility for parsing and type checking a sql statement -pub fn compile_sql_stmt(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Result { +pub fn compile_sql_stmt(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> Result { if sql.len() > MAX_SQL_LENGTH { bail!("SQL query exceeds maximum allowed length: \"{sql:.120}...\"") } diff --git a/crates/subscription/src/lib.rs b/crates/subscription/src/lib.rs index d0f4668459b..4b531470460 100644 --- a/crates/subscription/src/lib.rs +++ b/crates/subscription/src/lib.rs @@ -508,7 +508,7 @@ impl SubscriptionPlan { } /// Generate a plan for incrementally maintaining a subscription - pub fn compile(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> Result<(Vec, bool)> { + pub fn compile(sql: &str, tx: &mut impl SchemaView, auth: &AuthCtx) -> Result<(Vec, bool)> { let (plans, return_id, return_name, has_param) = compile_subscription(sql, tx, auth)?; /// Does this plan have any non-index joins?