diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index 36a74c39869..1e90f07abbf 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -279,8 +279,12 @@ pub fn run_from_module( return Err(anyhow!("Caller {} is not authorized to run SQL DML statements", auth.caller()).into()); } + stmt.for_each_return_field(|col_name, col_type| { + head.push((col_name.into(), col_type.clone())); + }); + // Evaluate the mutation - let (mut tx, _) = db.with_auto_rollback(tx, |tx| execute_dml_stmt(&auth, stmt, tx, &mut metrics))?; + let (mut tx, rows) = db.with_auto_rollback(tx, |tx| execute_dml_stmt(&auth, stmt, tx, &mut metrics))?; // Update transaction metrics tx.metrics.merge(metrics); @@ -344,7 +348,7 @@ pub fn run_from_module( Ok(( SqlResult { tx_offset: res.tx_offset, - rows: vec![], + rows, metrics, }, trapped, diff --git a/crates/execution/src/dml.rs b/crates/execution/src/dml.rs index baa78e1e0e6..ec4e9fb535c 100644 --- a/crates/execution/src/dml.rs +++ b/crates/execution/src/dml.rs @@ -1,10 +1,10 @@ use anyhow::Result; use spacetimedb_lib::{metrics::ExecutionMetrics, AlgebraicValue, ProductValue}; -use spacetimedb_physical_plan::dml::{DeletePlan, InsertPlan, MutationPlan, UpdatePlan}; +use spacetimedb_physical_plan::{dml::{DeletePlan, InsertPlan, MutationPlan, UpdatePlan}, plan::{ProjectField, ProjectListPlan}}; use spacetimedb_primitives::{ColId, TableId}; use spacetimedb_sats::size_of::SizeOf; -use crate::{pipelined::PipelinedProject, Datastore, DeltaStore}; +use crate::{pipelined::PipelinedProject, Datastore, DeltaStore, Row}; /// A mutable datastore can read as well as insert and delete rows pub trait MutDatastore: Datastore + DeltaStore { @@ -30,7 +30,7 @@ impl From for MutExecutor { } impl MutExecutor { - pub fn execute(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result<()> { + pub fn execute(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result> { match self { Self::Insert(exec) => exec.execute(tx, metrics), Self::Delete(exec) => exec.execute(tx, metrics), @@ -39,32 +39,51 @@ impl MutExecutor { } } +fn project_returning_row(returning: &ProjectListPlan, row: ProductValue) -> Option { + match returning { + ProjectListPlan::Name(_) => { + Some(row) + } + ProjectListPlan::List(_, fields) => { + let row = Row::Ref(&row); + Some(ProductValue::from_iter(fields.iter().map(|field| row.project(field)))) + } + _ => None + } +} + /// Executes row insertions pub struct InsertExecutor { table_id: TableId, rows: Vec, + returning: Option, } impl From for InsertExecutor { fn from(plan: InsertPlan) -> Self { Self { - rows: plan.rows, table_id: plan.table.table_id, + rows: plan.rows, + returning: plan.returning, } } } impl InsertExecutor { - fn execute(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result<()> { + fn execute(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result> { + let mut results = vec![]; for row in &self.rows { if tx.insert_product_value(self.table_id, row)? { metrics.rows_inserted += 1; + if let Some(returning) = &self.returning { + project_returning_row(returning, row.clone()).map(|res| results.push(res)); + } } } // TODO: It would be better to get this metric from the bsatn buffer. // But we haven't been concerned with optimizing DML up to this point. metrics.bytes_written += self.rows.iter().map(|row| row.size_of()).sum::(); - Ok(()) + Ok(results) } } @@ -72,6 +91,7 @@ impl InsertExecutor { pub struct DeleteExecutor { table_id: TableId, filter: PipelinedProject, + returning: Option, } impl From for DeleteExecutor { @@ -79,12 +99,13 @@ impl From for DeleteExecutor { Self { table_id: plan.table.table_id, filter: plan.filter.into(), + returning: plan.returning, } } } impl DeleteExecutor { - fn execute(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result<()> { + fn execute(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result> { // TODO: Delete by row id instead of product value let mut deletes = vec![]; self.filter.execute(tx, metrics, &mut |row| { @@ -95,12 +116,16 @@ impl DeleteExecutor { // Note, that we don't update bytes written, // because deletes don't actually write out any bytes. metrics.bytes_scanned += deletes.iter().map(|row| row.size_of()).sum::(); + let mut results = vec![]; for row in &deletes { if tx.delete_product_value(self.table_id, row)? { metrics.rows_deleted += 1; + if let Some(returning) = &self.returning { + project_returning_row(returning, row.clone()).map(|res| results.push(res)); + } } } - Ok(()) + Ok(results) } } @@ -109,6 +134,7 @@ pub struct UpdateExecutor { table_id: TableId, columns: Vec<(ColId, AlgebraicValue)>, filter: PipelinedProject, + returning: Option, } impl From for UpdateExecutor { @@ -117,12 +143,13 @@ impl From for UpdateExecutor { columns: plan.columns, table_id: plan.table.table_id, filter: plan.filter.into(), + returning: plan.returning, } } } impl UpdateExecutor { - fn execute(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result<()> { + fn execute(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result> { let mut deletes = vec![]; self.filter.execute(tx, metrics, &mut |row| { deletes.push(row.to_product_value()); @@ -134,6 +161,7 @@ impl UpdateExecutor { // TODO: This metric should be updated inline when we serialize. metrics.bytes_scanned = deletes.iter().map(|row| row.size_of()).sum::(); metrics.rows_updated += deletes.len() as u64; + let mut results = vec![]; for row in &deletes { let row = ProductValue::from_iter( row @@ -151,7 +179,10 @@ impl UpdateExecutor { ); tx.insert_product_value(self.table_id, &row)?; metrics.bytes_written += row.size_of(); + if let Some(returning) = &self.returning { + project_returning_row(returning, row).map(|res| results.push(res)); + } } - Ok(()) + Ok(results) } } diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index 46db2b02053..16dee54fd95 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -56,22 +56,38 @@ impl DML { pub fn table_name(&self) -> Box { self.table_schema().table_name.clone() } + + /// Iterate over the projected column names and types + pub fn for_each_return_field(&self, f: impl FnMut(&str, &AlgebraicType)) { + match self { + Self::Insert(TableInsert { returning, .. }) + | Self::Update(TableUpdate { returning, .. }) + | Self::Delete(TableDelete { returning, .. }) => { + if let Some(returning) = returning { + returning.for_each_return_field(f); + } + } + } + } } pub struct TableInsert { pub table: Arc, pub rows: Box<[ProductValue]>, + pub returning: Option, } pub struct TableDelete { pub table: Arc, pub filter: Option, + pub returning: Option, } pub struct TableUpdate { pub table: Arc, pub columns: Box<[(ColId, AlgebraicValue)]>, pub filter: Option, + pub returning: Option, } pub struct SetVar { @@ -89,6 +105,7 @@ pub fn type_insert(insert: SqlInsert, tx: &impl SchemaView) -> TypingResult TypingResult TypingResult TypingResult TypingResult TypingResult TypingResult, pub rows: Vec, + pub returning: Option, } impl From for InsertPlan { fn from(insert: TableInsert) -> Self { - let TableInsert { table, rows } = insert; + let TableInsert { table, rows, returning } = insert; let rows = rows.into_vec(); - Self { table, rows } + let returning = returning.map(compile_select_list); + Self { table, rows, returning } } } @@ -47,19 +49,20 @@ impl From for InsertPlan { pub struct DeletePlan { pub table: Arc, pub filter: ProjectPlan, + pub returning: Option, } impl DeletePlan { /// Optimize the filter part of the delete fn optimize(self, auth: &AuthCtx) -> Result { - let Self { table, filter } = self; + let Self { table, filter, returning } = self; let filter = filter.optimize(auth)?; - Ok(Self { table, filter }) + Ok(Self { table, filter, returning }) } /// Logical to physical conversion pub(crate) fn compile(delete: TableDelete) -> Self { - let TableDelete { table, filter } = delete; + let TableDelete { table, filter, returning } = delete; let schema = table.clone(); let alias = table.table_name.clone(); let relvar = RelExpr::RelVar(Relvar { @@ -72,7 +75,8 @@ impl DeletePlan { Some(expr) => ProjectName::None(RelExpr::Select(Box::new(relvar), expr)), }; let filter = compile_select(project); - Self { table, filter } + let returning = returning.map(compile_select_list); + Self { table, filter, returning } } } @@ -81,19 +85,20 @@ pub struct UpdatePlan { pub table: Arc, pub columns: Vec<(ColId, AlgebraicValue)>, pub filter: ProjectPlan, + pub returning: Option, } impl UpdatePlan { /// Optimize the filter part of the update fn optimize(self, auth: &AuthCtx) -> Result { - let Self { table, columns, filter } = self; + let Self { table, columns, filter, returning } = self; let filter = filter.optimize(auth)?; - Ok(Self { columns, table, filter }) + Ok(Self { columns, table, filter, returning }) } /// Logical to physical conversion pub(crate) fn compile(update: TableUpdate) -> Self { - let TableUpdate { table, columns, filter } = update; + let TableUpdate { table, columns, filter, returning } = update; let schema = table.clone(); let alias = table.table_name.clone(); let relvar = RelExpr::RelVar(Relvar { @@ -106,7 +111,8 @@ impl UpdatePlan { Some(expr) => ProjectName::None(RelExpr::Select(Box::new(relvar), expr)), }; let filter = compile_select(project); + let returning = returning.map(compile_select_list); let columns = columns.into_vec(); - Self { columns, table, filter } + Self { columns, table, filter, returning } } } diff --git a/crates/query/src/lib.rs b/crates/query/src/lib.rs index d75f2516a2c..92e75160033 100644 --- a/crates/query/src/lib.rs +++ b/crates/query/src/lib.rs @@ -92,7 +92,7 @@ pub fn execute_dml_stmt( stmt: DML, tx: &mut Tx, metrics: &mut ExecutionMetrics, -) -> Result<()> { +) -> Result> { let plan = compile_dml_plan(stmt).optimize(auth)?; let plan = MutExecutor::from(plan); plan.execute(tx, metrics) diff --git a/crates/sql-parser/src/ast/sql.rs b/crates/sql-parser/src/ast/sql.rs index 567b5ec5328..684ed3fef92 100644 --- a/crates/sql-parser/src/ast/sql.rs +++ b/crates/sql-parser/src/ast/sql.rs @@ -25,18 +25,32 @@ impl SqlAst { pub fn qualify_vars(self) -> Self { match self { Self::Select(select) => Self::Select(select.qualify_vars()), + Self::Insert(SqlInsert { + table: with, + fields, + values, + returning, + }) => Self::Insert(SqlInsert { + table: with.clone(), + fields, + values, + returning: returning.map(|proj| proj.qualify_vars(with)), + }), Self::Update(SqlUpdate { table: with, assignments, filter, + returning, }) => Self::Update(SqlUpdate { table: with.clone(), - filter: filter.map(|expr| expr.qualify_vars(with)), + filter: filter.map(|expr| expr.qualify_vars(with.clone())), assignments, + returning: returning.map(|proj| proj.qualify_vars(with)), }), - Self::Delete(SqlDelete { table: with, filter }) => Self::Delete(SqlDelete { + Self::Delete(SqlDelete { table: with, filter, returning }) => Self::Delete(SqlDelete { table: with.clone(), - filter: filter.map(|expr| expr.qualify_vars(with)), + filter: filter.map(|expr| expr.qualify_vars(with.clone())), + returning: returning.map(|proj| proj.qualify_vars(with)), }), _ => self, } @@ -111,6 +125,7 @@ pub struct SqlInsert { pub table: SqlIdent, pub fields: Vec, pub values: SqlValues, + pub returning: Option, } /// VALUES literals @@ -123,6 +138,7 @@ pub struct SqlUpdate { pub table: SqlIdent, pub assignments: Vec, pub filter: Option, + pub returning: Option, } impl SqlUpdate { @@ -140,6 +156,7 @@ impl SqlUpdate { pub struct SqlDelete { pub table: SqlIdent, pub filter: Option, + pub returning: Option, } impl SqlDelete { diff --git a/crates/sql-parser/src/parser/sql.rs b/crates/sql-parser/src/parser/sql.rs index a1eb5078726..668eed2c941 100644 --- a/crates/sql-parser/src/parser/sql.rs +++ b/crates/sql-parser/src/parser/sql.rs @@ -129,7 +129,7 @@ use sqlparser::{ ast::{ - Assignment, Expr, GroupByExpr, ObjectName, Query, Select, SetExpr, Statement, TableFactor, TableWithJoins, + Assignment, Expr, GroupByExpr, ObjectName, Query, Select, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins, Value, Values, }, dialect::PostgreSqlDialect, @@ -174,12 +174,13 @@ fn parse_statement(stmt: Statement) -> SqlParseResult { after_columns, table: false, on: None, - returning: None, + returning, .. } if after_columns.is_empty() => Ok(SqlAst::Insert(SqlInsert { table: parse_ident(table_name)?, fields: columns.into_iter().map(SqlIdent::from).collect(), values: parse_values(*source)?, + returning: returning.map(parse_projection).transpose()?, })), Statement::Update { table: @@ -198,19 +199,20 @@ fn parse_statement(stmt: Statement) -> SqlParseResult { assignments, from: None, selection, - returning: None, + returning, } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlAst::Update(SqlUpdate { table: parse_ident(name)?, assignments: parse_assignments(assignments)?, filter: parse_expr_opt(selection)?, + returning: returning.map(parse_projection).transpose()?, })), Statement::Delete { tables, from, using: None, selection, - returning: None, - } if tables.is_empty() => Ok(SqlAst::Delete(parse_delete(from, selection)?)), + returning, + } if tables.is_empty() => Ok(SqlAst::Delete(parse_delete(from, selection, returning)?)), Statement::SetVariable { local: false, hivevar: false, @@ -281,7 +283,7 @@ fn parse_assignment(Assignment { id, value }: Assignment) -> SqlParseResult, selection: Option) -> SqlParseResult { +fn parse_delete(mut from: Vec, selection: Option, returning: Option>) -> SqlParseResult { if from.len() == 1 { match from.swap_remove(0) { TableWithJoins { @@ -298,6 +300,7 @@ fn parse_delete(mut from: Vec, selection: Option) -> SqlPa } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlDelete { table: parse_ident(name)?, filter: parse_expr_opt(selection)?, + returning: returning.map(parse_projection).transpose()?, }), t => Err(SqlUnsupported::DeleteTable(t).into()), }