diff --git a/crates/cli/src/subcommands/sql.rs b/crates/cli/src/subcommands/sql.rs index 00a8d645e42..a93d16a6df6 100644 --- a/crates/cli/src/subcommands/sql.rs +++ b/crates/cli/src/subcommands/sql.rs @@ -239,7 +239,7 @@ mod tests { use spacetimedb_lib::error::ResultTest; use spacetimedb_lib::sats::time_duration::TimeDuration; use spacetimedb_lib::sats::timestamp::Timestamp; - use spacetimedb_lib::sats::{product, GroundSpacetimeType, ProductType}; + use spacetimedb_lib::sats::{product, ArrayValue, GroundSpacetimeType, ProductType}; use spacetimedb_lib::{AlgebraicType, AlgebraicValue, ConnectionId, Identity}; fn make_row(row: &[AlgebraicValue]) -> Result, serde_json::Error> { @@ -594,4 +594,124 @@ Roundtrip time: 1.00ms"#, Ok(()) } + + #[test] + fn output_arrays() -> ResultTest<()> { + let kind: ProductType = [("arr", AlgebraicType::array(AlgebraicType::I32))].into(); + let value = product![AlgebraicValue::Array([1, 2, 3].into()).clone()]; + + expect_psql_table( + PsqlClient::SpacetimeDB, + &kind, + vec![value.clone()], + r#" + arr +----------- + [1, 2, 3]"#, + ); + + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value.clone()], + r#" + arr +----------- + {1, 2, 3}"#, + ); + + let kind: ProductType = [("arr", AlgebraicType::array(AlgebraicType::array(AlgebraicType::I32)))].into(); + + let arr = ArrayValue::I32([1, 2, 3].into()); + let value = product![AlgebraicValue::Array(ArrayValue::Array([arr.clone(), arr].into()))]; + expect_psql_table( + PsqlClient::SpacetimeDB, + &kind, + vec![value.clone()], + r#" + arr +------------------------ + [[1, 2, 3], [1, 2, 3]]"#, + ); + + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value.clone()], + r#" + arr +------------------------ + {{1, 2, 3}, {1, 2, 3}}"#, + ); + + // Check struct + let kind: ProductType = [( + "arr", + AlgebraicType::array(AlgebraicType::product([ + ("a", AlgebraicType::I32), + ("b", AlgebraicType::String), + ])), + )] + .into(); + + let value = product![AlgebraicValue::Array(ArrayValue::Product( + [ + product![AlgebraicValue::I32(1), AlgebraicValue::String("one".into())], + product![AlgebraicValue::I32(2), AlgebraicValue::String("two".into())], + ] + .into() + ))]; + + expect_psql_table( + PsqlClient::SpacetimeDB, + &kind, + vec![value.clone()], + r#" + arr +------------------------------------------ + [(a = 1, b = "one"), (a = 2, b = "two")]"#, + ); + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value], + r#" + arr +---------------------------------------------- + {{"a": 1, "b": "one"}, {"a": 2, "b": "two"}}"#, + ); + + // Check struct with array fields + let kind: ProductType = [ + ("a", AlgebraicType::array(AlgebraicType::I32)), + ("b", AlgebraicType::array(AlgebraicType::String)), + ] + .into(); + let value = product![ + AlgebraicValue::Array(ArrayValue::I32([1, 2, 3].into())), + AlgebraicValue::Array(ArrayValue::String(["one".into(), "two".into(), "three".into()].into())) + ]; + + expect_psql_table( + PsqlClient::SpacetimeDB, + &kind, + vec![value.clone()], + r#" + a | b +-----------+------------------------- + [1, 2, 3] | ["one", "two", "three"]"#, + ); + + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value], + r#" + a | b +-----------+------------------------- + {1, 2, 3} | {"one", "two", "three"}"#, + ); + + Ok(()) + } } diff --git a/crates/pg/src/encoder.rs b/crates/pg/src/encoder.rs index f5a6ed990ed..997d1246bc0 100644 --- a/crates/pg/src/encoder.rs +++ b/crates/pg/src/encoder.rs @@ -2,8 +2,8 @@ use crate::pg_server::PgError; use pgwire::api::portal::Format; use pgwire::api::results::{DataRowEncoder, FieldInfo}; use pgwire::api::Type; -use spacetimedb_lib::sats::satn::{PsqlChars, PsqlPrintFmt, PsqlType, TypedWriter}; -use spacetimedb_lib::sats::{satn, ValueWithType}; +use spacetimedb_lib::sats::satn::{PsqlChars, PsqlClient, PsqlPrintFmt, PsqlType, TypedWriter}; +use spacetimedb_lib::sats::{satn, ArrayValue, ValueWithType}; use spacetimedb_lib::{ ser, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp, }; @@ -50,7 +50,11 @@ pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type { | AlgebraicType::U128 | AlgebraicType::I256 | AlgebraicType::U256 => Type::NUMERIC_ARRAY, - _ => Type::ANYARRAY, + AlgebraicType::F32 => Type::FLOAT4_ARRAY, + AlgebraicType::F64 => Type::FLOAT8_ARRAY, + AlgebraicType::Ref(_) | AlgebraicType::Sum(_) | AlgebraicType::Product(_) | AlgebraicType::Array(_) => { + Type::JSON_ARRAY + } }, AlgebraicType::Product(_) => match format { PsqlPrintFmt::Hex => Type::BYTEA_ARRAY, @@ -74,6 +78,39 @@ pub(crate) struct PsqlFormatter<'a> { pub(crate) encoder: &'a mut DataRowEncoder, } +impl<'a> PsqlFormatter<'a> { + fn encode_variant(tag: u8, ty: PsqlType, name: Option<&str>, value: ValueWithType) -> String { + // Is a simple enum? + if let AlgebraicType::Sum(sum) = &ty.field.algebraic_type { + if sum.is_simple_enum() { + if let Some(variant_name) = name { + return variant_name.to_string(); + } + } + } + + if ty.field.algebraic_type.is_unit() { + if let Some(variant_name) = name { + return variant_name.to_string(); + } + } + + let PsqlChars { + start, + sep, + end, + quote, + start_array: _, + end_array: _, + } = ty.client.format_chars(); + let name = name.map(Cow::from).unwrap_or_else(|| Cow::from(tag.to_string())); + format!( + "{start}{quote}{name}{quote}{sep} {}{end}", + satn::PsqlWrapper { ty, value } + ) + } +} + impl TypedWriter for PsqlFormatter<'_> { type Error = PgError; @@ -146,7 +183,14 @@ impl TypedWriter for PsqlFormatter<'_> { } } - let PsqlChars { start, sep, end, quote } = ty.client.format_chars(); + let PsqlChars { + start, + sep, + end, + quote, + start_array: _, + end_array: _, + } = ty.client.format_chars(); let name = name.map(Cow::from).unwrap_or_else(|| Cow::from(tag.to_string())); let json = format!( "{start}{quote}{name}{quote}{sep} {}{end}", @@ -155,6 +199,129 @@ impl TypedWriter for PsqlFormatter<'_> { self.encoder.encode_field(&json)?; Ok(()) } + + fn write_array( + &mut self, + value: &ValueWithType<'_, ArrayValue>, + psql: &PsqlType, + ty: &AlgebraicType, + ) -> Result { + if *ty == AlgebraicType::U8 { + return Ok(false); + } + fn collect(arr: &[I], map: F) -> Vec + where + I: Clone, + F: Fn(usize, &I) -> O, + { + arr.iter().enumerate().map(|(pos, v)| map(pos, v)).collect() + } + let ty = &value.ty().elem_ty; + let type_space = &value.typespace(); + match value.value() { + ArrayValue::Bool(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::I8(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::U8(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::I16(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::U16(arr) => self.encoder.encode_field(&collect(arr, |_, v| *v as i32))?, + ArrayValue::I32(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::U32(arr) => self.encoder.encode_field(&collect(arr, |_, v| *v as i64))?, + ArrayValue::I64(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::U64(arr) => self.encoder.encode_field(&collect(arr, |_, v| v.to_string()))?, + ArrayValue::I128(arr) => self.encoder.encode_field(&collect(arr, |_, v| v.to_string()))?, + ArrayValue::U128(arr) => self.encoder.encode_field(&collect(arr, |_, v| v.to_string()))?, + ArrayValue::I256(arr) => self.encoder.encode_field(&collect(arr, |_, v| v.to_string()))?, + ArrayValue::U256(arr) => self.encoder.encode_field(&collect(arr, |_, v| v.to_string()))?, + ArrayValue::F32(arr) => self.encoder.encode_field(&collect(arr, |_, v| *v.as_ref()))?, + ArrayValue::F64(arr) => self.encoder.encode_field(&collect(arr, |_, v| *v.as_ref()))?, + ArrayValue::String(arr) => self.encoder.encode_field(&collect(arr, |_, v| v.to_string()))?, + ArrayValue::Array(arr) => { + let values = collect(arr, |_pos, val| { + let mut psql = psql.clone(); + // Switching client because we are outputting nested arrays as JSON + psql.client = PsqlClient::SpacetimeDB; + satn::PsqlWrapper { + ty: psql, + value: val.clone(), + } + .to_string() + }); + self.encoder.encode_field(&values)?; + } + ArrayValue::Sum(sum) => { + let values = collect(sum, |_pos, val| { + let (tag, value) = match &**ty { + AlgebraicType::Sum(sum) => { + let field = sum.variants.get(val.tag as usize).expect("Invalid variant tag"); + (field, val.value.clone()) + } + _ => unreachable!("Expected sum type"), + }; + let field = ProductTypeElement::new(tag.algebraic_type.clone(), tag.name.clone()); + + PsqlFormatter::encode_variant( + val.tag, + PsqlType { + client: psql.client, + field: &field.clone(), + tuple: &ProductType::new([field].into()), + idx: 0, + }, + tag.name.as_deref(), + ValueWithType::new(type_space.with_type(&tag.algebraic_type), &value), + ) + }); + self.encoder.encode_field(&values)?; + } + ArrayValue::Product(value) => { + let PsqlChars { + start, + sep, + end, + quote, + start_array: _, + end_array: _, + } = psql.client.format_chars(); + let values = collect(value, |pos, value| { + let json = match &**ty { + AlgebraicType::Product(prod) => { + let mut json = String::new(); + for (field, value) in prod.elements.iter().zip(value.elements.iter()) { + let psql_ty = PsqlType { + client: psql.client, + field, + tuple: prod, + idx: pos, + }; + if !json.is_empty() { + json.push(','); + } + let name = field + .name + .as_deref() + .map(Cow::from) + .unwrap_or_else(|| Cow::from(pos.to_string())); + let field_json = + format!("{quote}{name}{quote}{sep} {}", satn::PsqlWrapper { ty: psql_ty, value }); + json.push_str(&field_json); + } + json + } + _ => unreachable!("Expected product type"), + }; + format!("{start}{}{end}", json) + }); + + self.encoder.encode_field(&values)?; + } + } + + Ok(true) + } + + fn insert_sep(&mut self, _sep: &str) -> Result<(), Self::Error> { + Ok(()) // No-op for PSQL format + } } #[cfg(test)] @@ -164,7 +331,9 @@ mod tests { use futures::StreamExt; use spacetimedb_client_api_messages::http::SqlStmtResult; use spacetimedb_lib::sats::algebraic_value::Packed; - use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ProductType, SumTypeVariant}; + use spacetimedb_lib::sats::{ + i256, product, u256, AlgebraicType, ArrayValue, ProductType, SumTypeVariant, SumValue, + }; use spacetimedb_lib::{ConnectionId, Identity}; async fn run(schema: ProductType, row: ProductValue) -> String { @@ -298,4 +467,150 @@ mod tests { let row = run(schema, value).await; assert_eq!(row, "\0\0\0B\\x0000000000000000000000000000000000000000000000000000000000000000\0\0\0\"\\x00000000000000000000000000000000\0\0\0\u{3}P0D\0\0\0\u{1d}1970-01-19T18:42:25.800+00:00\0\0\0\n\\x74657374"); } + + #[tokio::test] + async fn test_array() { + // {a: [1,2,3], b: [{"a": 1}, {"b": true}], c: [0xDE, 0xAD, 0xBE, 0xEF]} + let product = AlgebraicType::product([ + ProductTypeElement::new(AlgebraicType::I32, Some("a".into())), + ProductTypeElement::new(AlgebraicType::Bool, Some("b".into())), + ]); + let schema = ProductType::from([ + AlgebraicType::array(AlgebraicType::I32), + AlgebraicType::array(product.clone()), + AlgebraicType::bytes(), + ]); + + let value = product![ + AlgebraicValue::Array(ArrayValue::I32([1, 2, 3].into())), + AlgebraicValue::Array(ArrayValue::Product([product![1, true]].into())), + AlgebraicValue::Bytes([0xDE, 0xAD, 0xBE, 0xEF].into()), + ]; + + let row = run(schema.clone(), value.clone()).await; + assert_eq!( + row, + "\0\0\0\u{7}{1,2,3}\0\0\0\u{1a}{\"{\\\"a\\\": 1,\\\"b\\\": true}\"}\0\0\0\n\\xdeadbeef" + ); + + // Check all the unnested arrays are encoded as native PG arrays, and nested arrays, sum & product arrays as JSON + let arrays = vec![ + ( + ArrayValue::Bool([true, false, true].into()), + AlgebraicType::Bool, + "\u{7}{t,f,t}", + ), + (ArrayValue::I8([-1, 0, 1].into()), AlgebraicType::I8, "\u{8}{-1,0,1}"), + (ArrayValue::U8([0, 1, 2].into()), AlgebraicType::U8, "\u{8}\\x000102"), + ( + ArrayValue::I16([-256, 0, 256].into()), + AlgebraicType::I16, + "\u{c}{-256,0,256}", + ), + ( + ArrayValue::U16([0, 256, 65535].into()), + AlgebraicType::U16, + "\r{0,256,65535}", + ), + ( + ArrayValue::I32([-65536, 0, 65536].into()), + AlgebraicType::I32, + "\u{10}{-65536,0,65536}", + ), + ( + ArrayValue::U32([0, 65536, 4294967295].into()), + AlgebraicType::U32, + "\u{14}{0,65536,4294967295}", + ), + ( + ArrayValue::I64([-4294967296, 0, 4294967296].into()), + AlgebraicType::I64, + "\u{1a}{-4294967296,0,4294967296}", + ), + ( + ArrayValue::U64([0, 4294967296, 18446744073709551615].into()), + AlgebraicType::U64, + "#{0,4294967296,18446744073709551615}", + ), + ( + ArrayValue::I128([i128::MIN, 0, i128::MAX].into()), + AlgebraicType::I128, + "T{-170141183460469231731687303715884105728,0,170141183460469231731687303715884105727}", + ), + ( + ArrayValue::U128([0, u128::MAX].into()), + AlgebraicType::U128, + "+{0,340282366920938463463374607431768211455}", + ), + ( + ArrayValue::I256([i256::from(-1), i256::from(0), i256::from(1)].into()), + AlgebraicType::I256, + "\u{8}{-1,0,1}", + ), + ( + ArrayValue::U256([u256::ZERO, u256::ONE].into()), + AlgebraicType::U256, + "\u{5}{0,1}", + ), + ( + ArrayValue::F32([1.5.into(), 2.5.into(), 3.5.into()].into()), + AlgebraicType::F32, + "\r{1.5,2.5,3.5}", + ), + ( + ArrayValue::F64([1.5.into(), 2.5.into(), 3.5.into()].into()), + AlgebraicType::F64, + "\r{1.5,2.5,3.5}", + ), + ( + ArrayValue::String(["foo".into(), "bar".into(), "baz".into()].into()), + AlgebraicType::String, + "\r{foo,bar,baz}", + ), + ( + ArrayValue::Product([product![1], product![2], product![3]].into()), + AlgebraicType::product([ProductTypeElement::new(AlgebraicType::I32, None)]), + "({\"{\\\"0\\\": 1}\",\"{\\\"1\\\": 2}\",\"{\\\"2\\\": 3}\"}", + ), + // Array of arrays + ( + ArrayValue::Array([ArrayValue::I32([1, 2].into()), ArrayValue::I32([3, 4].into())].into()), + AlgebraicType::array(AlgebraicType::I32), + "\u{13}{\"[1, 2]\",\"[3, 4]\"}", + ), + // Simple enum array + ( + ArrayValue::Sum( + [ + SumValue::new_simple(0), + SumValue::new_simple(1), + SumValue::new_simple(2), + ] + .into(), + ), + AlgebraicType::simple_enum(["A", "B", "C"].into_iter()), + "\u{7}{A,B,C}", + ), + // Non-simple enum array + ( + ArrayValue::Sum( + [ + SumValue::new(0, AlgebraicValue::I64(1)), + SumValue::new(1, AlgebraicValue::unit()), + ] + .into(), + ), + AlgebraicType::option(AlgebraicType::I64), + "\u{16}{\"{\\\"some\\\": 1}\",none}", + ), + ]; + + for (array_value, ty, expected_encoding) in arrays { + let schema = ProductType::from([AlgebraicType::array(ty.clone())]); + let value = product![AlgebraicValue::Array(array_value)]; + let row = run(schema, value).await; + let expected_row = format!("\0\0\0{}", expected_encoding); + assert_eq!(row, expected_row, "Failed for array encoding for {ty:?}"); + } + } } diff --git a/crates/sats/src/satn.rs b/crates/sats/src/satn.rs index 2a5f128262e..9173696cea6 100644 --- a/crates/sats/src/satn.rs +++ b/crates/sats/src/satn.rs @@ -1,6 +1,7 @@ +use crate::ser::SerializeArray; use crate::time_duration::TimeDuration; use crate::timestamp::Timestamp; -use crate::{i256, u256, AlgebraicType, AlgebraicValue, ProductValue, Serialize, SumValue, ValueWithType}; +use crate::{i256, u256, AlgebraicType, AlgebraicValue, ArrayValue, ProductValue, Serialize, SumValue, ValueWithType}; use crate::{ser, ProductType, ProductTypeElement}; use core::fmt; use core::fmt::Write as _; @@ -11,8 +12,8 @@ use std::marker::PhantomData; /// An extension trait for [`Serialize`] providing formatting methods. pub trait Satn: ser::Serialize { /// Formats the value using the SATN data format into the formatter `f`. - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - Writer::with(f, |f| self.serialize(SatnFormatter { f }))?; + fn fmt(&self, f: &mut fmt::Formatter, client: PsqlClient) -> fmt::Result { + Writer::with(f, |f| self.serialize(SatnFormatter { f, client }))?; Ok(()) } @@ -22,7 +23,7 @@ pub trait Satn: ser::Serialize { self.serialize(TypedSerializer { ty, f: &mut SqlFormatter { - fmt: SatnFormatter { f }, + fmt: SatnFormatter { f, client: ty.client }, ty, }, }) @@ -60,13 +61,13 @@ impl Wrapper { impl fmt::Display for Wrapper { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) + self.0.fmt(f, PsqlClient::SpacetimeDB) } } impl fmt::Debug for Wrapper { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) + self.0.fmt(f, PsqlClient::SpacetimeDB) } } @@ -231,6 +232,7 @@ impl fmt::Write for Writer<'_, '_> { struct SatnFormatter<'a, 'f> { /// The sink / writer / output / formatter. f: Writer<'a, 'f>, + client: PsqlClient, } impl SatnFormatter<'_, '_> { @@ -240,16 +242,18 @@ impl SatnFormatter<'_, '_> { name: Option<&str>, value: &T, ) -> Result<(), SatnError> { - write!(self, "(")?; + let chars = self.client.format_chars(); + + write!(self, "{}", chars.start)?; EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| { if let Some(name) = name { write!(f, "{name}")?; } - write!(f, " = ")?; - value.serialize(SatnFormatter { f })?; + write!(f, " {}", chars.sep)?; + value.serialize(SatnFormatter { f, client: self.client })?; Ok(()) })?; - write!(self, ")")?; + write!(self, "{}", chars.end)?; Ok(()) } @@ -375,7 +379,13 @@ impl ser::SerializeArray for ArrayFormatter<'_, '_> { type Error = SatnError; fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { - self.f.entry(|f| elem.serialize(SatnFormatter { f }).map_err(|e| e.0))?; + self.f.entry(|f| { + elem.serialize(SatnFormatter { + f, + client: PsqlClient::SpacetimeDB, + }) + .map_err(|e| e.0) + })?; Ok(()) } @@ -429,7 +439,10 @@ impl ser::SerializeNamedProduct for NamedFormatter<'_, '_> { write!(f, "{}", self.idx)?; } write!(f, " = ")?; - elem.serialize(SatnFormatter { f })?; + elem.serialize(SatnFormatter { + f, + client: PsqlClient::SpacetimeDB, + })?; Ok(()) }); self.idx += 1; @@ -452,8 +465,10 @@ pub enum PsqlClient { pub struct PsqlChars { pub start: char, + pub start_array: &'static str, pub sep: &'static str, pub end: char, + pub end_array: &'static str, pub quote: &'static str, } @@ -462,14 +477,18 @@ impl PsqlClient { match self { PsqlClient::SpacetimeDB => PsqlChars { start: '(', + start_array: "[", sep: " =", end: ')', + end_array: "]", quote: "", }, PsqlClient::Postgres => PsqlChars { start: '{', + start_array: "{", sep: ":", end: '}', + end_array: "}", quote: "\"", }, } @@ -592,12 +611,22 @@ pub trait TypedWriter { name: Option<&str>, value: ValueWithType, ) -> Result<(), Self::Error>; + + fn write_array( + &mut self, + value: &ValueWithType<'_, ArrayValue>, + psql: &PsqlType, + ty: &AlgebraicType, + ) -> Result; + + fn insert_sep(&mut self, sep: &str) -> Result<(), Self::Error>; } /// A formatter for arrays that uses the `TypedWriter` trait to write elements. pub struct TypedArrayFormatter<'a, 'f, F> { ty: &'a PsqlType<'a>, f: &'f mut F, + first: bool, } impl ser::SerializeArray for TypedArrayFormatter<'_, '_, F> { @@ -605,11 +634,17 @@ impl ser::SerializeArray for TypedArrayFormatter<'_, '_, F> { type Error = F::Error; fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { + if !self.first { + self.f.insert_sep(", ")?; + } else { + self.first = false; + } elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?; Ok(()) } fn end(self) -> Result { + self.f.insert_sep(self.ty.client.format_chars().end_array)?; // Closed via `.end()`. Ok(()) } } @@ -750,9 +785,13 @@ impl<'a, 'f, F: TypedWriter> ser::Serializer for TypedSerializer<'a, 'f, F> { self.f.write_bytes(v) } } - fn serialize_array(self, _len: usize) -> Result { - Ok(TypedArrayFormatter { ty: self.ty, f: self.f }) + self.f.insert_sep(self.ty.client.format_chars().start_array)?; // Closed via `.end()`. + Ok(TypedArrayFormatter { + ty: self.ty, + f: self.f, + first: true, + }) } fn serialize_seq_product(self, _len: usize) -> Result { @@ -833,6 +872,43 @@ impl<'a, 'f, F: TypedWriter> ser::Serializer for TypedSerializer<'a, 'f, F> { ) -> Result { unreachable!("Use `serialize_variant_raw` instead."); } + + fn serialize_array_raw(self, value: &ValueWithType<'_, ArrayValue>) -> Result { + let mut ty = &*value.ty().elem_ty; + if self.f.write_array(value, self.ty, ty)? { + return Ok(()); + } + loop { + // We're doing this because of `Ref`s. + break match (value.value(), ty) { + (_, &AlgebraicType::Ref(r)) => { + ty = &value.typespace()[r]; + continue; + } + (ArrayValue::Sum(v), AlgebraicType::Sum(ty)) => value.with(ty, v).serialize(self), + (ArrayValue::Product(v), AlgebraicType::Product(ty)) => value.with(ty, v).serialize(self), + (ArrayValue::Bool(v), AlgebraicType::Bool) => v.serialize(self), + (ArrayValue::I8(v), AlgebraicType::I8) => v.serialize(self), + (ArrayValue::U8(v), AlgebraicType::U8) => v.serialize(self), + (ArrayValue::I16(v), AlgebraicType::I16) => v.serialize(self), + (ArrayValue::U16(v), AlgebraicType::U16) => v.serialize(self), + (ArrayValue::I32(v), AlgebraicType::I32) => v.serialize(self), + (ArrayValue::U32(v), AlgebraicType::U32) => v.serialize(self), + (ArrayValue::I64(v), AlgebraicType::I64) => v.serialize(self), + (ArrayValue::U64(v), AlgebraicType::U64) => v.serialize(self), + (ArrayValue::I128(v), AlgebraicType::I128) => v.serialize(self), + (ArrayValue::U128(v), AlgebraicType::U128) => v.serialize(self), + (ArrayValue::I256(v), AlgebraicType::I256) => v.serialize(self), + (ArrayValue::U256(v), AlgebraicType::U256) => v.serialize(self), + (ArrayValue::F32(v), AlgebraicType::F32) => v.serialize(self), + (ArrayValue::F64(v), AlgebraicType::F64) => v.serialize(self), + (ArrayValue::String(v), AlgebraicType::String) => v.serialize(self), + (ArrayValue::Array(v), AlgebraicType::Array(ty)) => value.with(ty, v).serialize(self), + (val, _) if val.is_empty() => self.serialize_array(0)?.end(), + (val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"), + }; + } + } } impl TypedWriter for SqlFormatter<'_, '_> { @@ -879,7 +955,14 @@ impl TypedWriter for SqlFormatter<'_, '_> { &mut self, fields: Vec<(Cow, PsqlType<'_>, ValueWithType)>, ) -> Result<(), Self::Error> { - let PsqlChars { start, sep, end, quote } = self.ty.client.format_chars(); + let PsqlChars { + start, + start_array: _, + sep, + end, + end_array: _, + quote, + } = self.ty.client.format_chars(); write!(self.fmt, "{start}")?; for (idx, (name, ty, value)) in fields.into_iter().enumerate() { if idx > 0 { @@ -907,4 +990,17 @@ impl TypedWriter for SqlFormatter<'_, '_> { value, )]) } + + fn write_array( + &mut self, + _value: &ValueWithType<'_, ArrayValue>, + _psql: &PsqlType, + _ty: &AlgebraicType, + ) -> Result { + Ok(false) + } + + fn insert_sep(&mut self, sep: &str) -> Result<(), Self::Error> { + write!(self.fmt, "{sep}") + } } diff --git a/crates/sats/src/ser.rs b/crates/sats/src/ser.rs index 8236fbe9c3a..7548d078e60 100644 --- a/crates/sats/src/ser.rs +++ b/crates/sats/src/ser.rs @@ -6,7 +6,10 @@ mod impls; pub mod serde; use crate::de::DeserializeSeed; -use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, ProductValue, SumValue, ValueWithType}; +use crate::{ + algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, AlgebraicType, ArrayValue, ProductValue, SumValue, + ValueWithType, +}; use crate::{AlgebraicValue, WithTypespace}; use core::marker::PhantomData; use core::{convert::Infallible, fmt}; @@ -150,6 +153,40 @@ pub trait Serializer: Sized { value: &T, ) -> Result; + fn serialize_array_raw(self, value: &ValueWithType<'_, ArrayValue>) -> Result { + let mut ty = &*value.ty().elem_ty; + loop { + // We're doing this because of `Ref`s. + break match (value.value(), ty) { + (_, &AlgebraicType::Ref(r)) => { + ty = &value.typespace()[r]; + continue; + } + (ArrayValue::Sum(v), AlgebraicType::Sum(ty)) => value.with(ty, v).serialize(self), + (ArrayValue::Product(v), AlgebraicType::Product(ty)) => value.with(ty, v).serialize(self), + (ArrayValue::Bool(v), AlgebraicType::Bool) => v.serialize(self), + (ArrayValue::I8(v), AlgebraicType::I8) => v.serialize(self), + (ArrayValue::U8(v), AlgebraicType::U8) => v.serialize(self), + (ArrayValue::I16(v), AlgebraicType::I16) => v.serialize(self), + (ArrayValue::U16(v), AlgebraicType::U16) => v.serialize(self), + (ArrayValue::I32(v), AlgebraicType::I32) => v.serialize(self), + (ArrayValue::U32(v), AlgebraicType::U32) => v.serialize(self), + (ArrayValue::I64(v), AlgebraicType::I64) => v.serialize(self), + (ArrayValue::U64(v), AlgebraicType::U64) => v.serialize(self), + (ArrayValue::I128(v), AlgebraicType::I128) => v.serialize(self), + (ArrayValue::U128(v), AlgebraicType::U128) => v.serialize(self), + (ArrayValue::I256(v), AlgebraicType::I256) => v.serialize(self), + (ArrayValue::U256(v), AlgebraicType::U256) => v.serialize(self), + (ArrayValue::F32(v), AlgebraicType::F32) => v.serialize(self), + (ArrayValue::F64(v), AlgebraicType::F64) => v.serialize(self), + (ArrayValue::String(v), AlgebraicType::String) => v.serialize(self), + (ArrayValue::Array(v), AlgebraicType::Array(ty)) => value.with(ty, v).serialize(self), + (val, _) if val.is_empty() => self.serialize_array(0)?.end(), + (val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"), + }; + } + } + /// Serialize the given `bsatn` encoded data of type `ty`. /// /// This is a concession to performance, diff --git a/crates/sats/src/ser/impls.rs b/crates/sats/src/ser/impls.rs index 9baac393dff..c5adced7bae 100644 --- a/crates/sats/src/ser/impls.rs +++ b/crates/sats/src/ser/impls.rs @@ -225,36 +225,7 @@ impl_serialize!([] ValueWithType<'_, ProductValue>, (self, ser) => { ser.serialize_named_product_raw(self) }); impl_serialize!([] ValueWithType<'_, ArrayValue>, (self, ser) => { - let mut ty = &*self.ty().elem_ty; - loop { // We're doing this because of `Ref`s. - break match (self.value(), ty) { - (_, &AlgebraicType::Ref(r)) => { - ty = &self.typespace()[r]; - continue; - } - (ArrayValue::Sum(v), AlgebraicType::Sum(ty)) => self.with(ty, v).serialize(ser), - (ArrayValue::Product(v), AlgebraicType::Product(ty)) => self.with(ty, v).serialize(ser), - (ArrayValue::Bool(v), AlgebraicType::Bool) => v.serialize(ser), - (ArrayValue::I8(v), AlgebraicType::I8) => v.serialize(ser), - (ArrayValue::U8(v), AlgebraicType::U8) => v.serialize(ser), - (ArrayValue::I16(v), AlgebraicType::I16) => v.serialize(ser), - (ArrayValue::U16(v), AlgebraicType::U16) => v.serialize(ser), - (ArrayValue::I32(v), AlgebraicType::I32) => v.serialize(ser), - (ArrayValue::U32(v), AlgebraicType::U32) => v.serialize(ser), - (ArrayValue::I64(v), AlgebraicType::I64) => v.serialize(ser), - (ArrayValue::U64(v), AlgebraicType::U64) => v.serialize(ser), - (ArrayValue::I128(v), AlgebraicType::I128) => v.serialize(ser), - (ArrayValue::U128(v), AlgebraicType::U128) => v.serialize(ser), - (ArrayValue::I256(v), AlgebraicType::I256) => v.serialize(ser), - (ArrayValue::U256(v), AlgebraicType::U256) => v.serialize(ser), - (ArrayValue::F32(v), AlgebraicType::F32) => v.serialize(ser), - (ArrayValue::F64(v), AlgebraicType::F64) => v.serialize(ser), - (ArrayValue::String(v), AlgebraicType::String) => v.serialize(ser), - (ArrayValue::Array(v), AlgebraicType::Array(ty)) => self.with(ty, v).serialize(ser), - (val, _) if val.is_empty() => ser.serialize_array(0)?.end(), - (val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"), - } - } + ser.serialize_array_raw(self) }); impl_serialize!([] spacetimedb_primitives::ArgId, (self, ser) => ser.serialize_u64(self.0)); diff --git a/smoketests/tests/pg_wire.py b/smoketests/tests/pg_wire.py index 3e86d619473..86460587d76 100644 --- a/smoketests/tests/pg_wire.py +++ b/smoketests/tests/pg_wire.py @@ -92,6 +92,21 @@ class SqlFormat(Smoketest): ints: TInts, } +#[spacetimedb::table(name = t_player)] +pub struct TPlayer { + id: u32, + name: String, +} + +#[spacetimedb::table(name = t_arrays)] +pub struct TArrays { + pos: Vec, + velocity: Vec, + colors: Vec>, + colors_2: Vec>, + players: Vec, +} + #[spacetimedb::reducer] pub fn test(ctx: &ReducerContext) { let tuple = TInts { @@ -141,6 +156,17 @@ class SqlFormat(Smoketest): se: TSimpleEnum { id: 2, action: Action::Active }, ints, }); + + ctx.db.t_arrays().insert(TArrays { + pos: vec![1, 2, 3], + velocity: vec![0.1, 0.2, 0.3], + colors: vec![vec![255, 0, 0], vec![0, 255, 0], vec![0, 0, 255]], + colors_2: vec![vec![65535, 0, 0], vec![0, 65535, 0], vec![0, 0, 65535]], + players: vec![ + TPlayer { id: 1, name: "Alice".to_string() }, + TPlayer { id: 2, name: "Bob".to_string() }, + ], + }); } """ @@ -235,6 +261,11 @@ def test_sql_format(self): -----------------------------------+-------------------------------------+--------------------------------------------------------------------------------------------------------- {"id": 1, "color": {"Gray": 128}} | {"id": 2, "action": {"Active": {}}} | {"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853} (1 row)""") + self.assertPsql(token, "SELECT * FROM t_arrays", r""" +pos | velocity | colors | colors_2 | players +---------+---------------+------------------------------+---------------------------------------------------+--------------------------------------------------------------------- + {1,2,3} | {0.1,0.2,0.3} | {0xff0000,0x00ff00,0x0000ff} | {"[65535, 0, 0]","[0, 65535, 0]","[0, 0, 65535]"} | {"{\"id\": 1,\"name\": \"Alice\"}","{\"id\": 2,\"name\": \"Bob\"}"} +(1 row)""".strip()) def test_sql_conn(self): """This test is designed to test connecting to the database and executing queries using `psycopg2`""" diff --git a/smoketests/tests/sql.py b/smoketests/tests/sql.py index 6ea5082ebe6..1f50b7db09e 100644 --- a/smoketests/tests/sql.py +++ b/smoketests/tests/sql.py @@ -4,7 +4,7 @@ class SqlFormat(Smoketest): MODULE_CODE = """ use spacetimedb::sats::{i256, u256}; -use spacetimedb::{table, ConnectionId, Identity, ReducerContext, Table, Timestamp, TimeDuration}; +use spacetimedb::{ConnectionId, Identity, ReducerContext, Table, Timestamp, TimeDuration}; #[derive(Copy, Clone)] #[spacetimedb::table(name = t_ints)] @@ -57,6 +57,21 @@ class SqlFormat(Smoketest): tuple: TOthers } +#[spacetimedb::table(name = t_player)] +pub struct TPlayer { + id: Identity, + name: String, +} + +#[spacetimedb::table(name = t_arrays)] +pub struct TArrays { + pos: Vec, + velocity: Vec, + colors: Vec>, + colors_2: Vec>, + players: Vec, +} + #[spacetimedb::reducer] pub fn test(ctx: &ReducerContext) { let tuple = TInts { @@ -94,6 +109,17 @@ class SqlFormat(Smoketest): }; ctx.db.t_others().insert(tuple.clone()); ctx.db.t_others_tuple().insert(TOthersTuple { tuple }); + + ctx.db.t_arrays().insert(TArrays { + pos: vec![1, 2, 3], + velocity: vec![0.1, 0.2, 0.3], + colors: vec![vec![255, 0, 0], vec![0, 255, 0], vec![0, 0, 255]], + colors_2: vec![vec![65535, 0, 0], vec![0, 65535, 0], vec![0, 0, 65535]], + players: vec![ + TPlayer { id: Identity::ZERO, name: "Alice".to_string() }, + TPlayer { id: Identity::ONE, name: "Bob".to_string() }, + ], + }); } """ @@ -132,3 +158,8 @@ def test_sql_format(self): ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- (bool = true, f32 = 594806.56, f64 = -3454353.345389043, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000001, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000) """) + self.assertSql("SELECT * FROM t_arrays", """\ + pos | velocity | colors | colors_2 | players +-----------+-----------------+--------------------------------+-----------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + [1, 2, 3] | [0.1, 0.2, 0.3] | [0xff0000, 0x00ff00, 0x0000ff] | [[65535, 0, 0], [0, 65535, 0], [0, 0, 65535]] | [(id = 0x0000000000000000000000000000000000000000000000000000000000000000, name = "Alice"), (id = 0x0000000000000000000000000000000000000000000000000000000000000001, name = "Bob")] +""") \ No newline at end of file