1111from typing_extensions import Annotated
1212
1313from psqlpy import ConnectionPool
14- from psqlpy ._internal . extra_types import PyCustomType
14+ from psqlpy .exceptions import PyToRustValueMappingError
1515from psqlpy .extra_types import (
1616 BigInt ,
1717 Float32 ,
2020 Money ,
2121 PyBox ,
2222 PyCircle ,
23+ PyCustomType ,
2324 PyJSON ,
2425 PyJSONB ,
2526 PyLine ,
@@ -385,6 +386,41 @@ async def test_as_class(
385386 },
386387 ],
387388 ),
389+ (
390+ "JSON ARRAY" ,
391+ [
392+ [
393+ PyJSON (
394+ {
395+ "test" : ["something" , 123 , "here" ],
396+ "nested" : ["JSON" ],
397+ },
398+ ),
399+ ],
400+ [
401+ PyJSON (
402+ {
403+ "test" : ["something" , 123 , "here" ],
404+ "nested" : ["JSON" ],
405+ },
406+ ),
407+ ],
408+ ],
409+ [
410+ [
411+ {
412+ "test" : ["something" , 123 , "here" ],
413+ "nested" : ["JSON" ],
414+ },
415+ ],
416+ [
417+ {
418+ "test" : ["something" , 123 , "here" ],
419+ "nested" : ["JSON" ],
420+ },
421+ ],
422+ ],
423+ ),
388424 (
389425 "JSON ARRAY" ,
390426 [
@@ -396,6 +432,17 @@ async def test_as_class(
396432 [{"array" : "json" }, {"one more" : "test" }],
397433 ],
398434 ),
435+ (
436+ "JSON ARRAY" ,
437+ [
438+ PyJSON ([[{"array" : "json" }], [{"one more" : "test" }]]),
439+ PyJSON ([[{"array" : "json" }], [{"one more" : "test" }]]),
440+ ],
441+ [
442+ [[{"array" : "json" }], [{"one more" : "test" }]],
443+ [[{"array" : "json" }], [{"one more" : "test" }]],
444+ ],
445+ ),
399446 (
400447 "POINT ARRAY" ,
401448 [
@@ -407,6 +454,17 @@ async def test_as_class(
407454 (2.0 , 3.0 ),
408455 ],
409456 ),
457+ (
458+ "POINT ARRAY" ,
459+ [
460+ [PyPoint ([1.5 , 2 ])],
461+ [PyPoint ([2 , 3 ])],
462+ ],
463+ [
464+ [(1.5 , 2.0 )],
465+ [(2.0 , 3.0 )],
466+ ],
467+ ),
410468 (
411469 "BOX ARRAY" ,
412470 [
@@ -418,6 +476,17 @@ async def test_as_class(
418476 ((9.0 , 9.0 ), (8.5 , 8.0 )),
419477 ],
420478 ),
479+ (
480+ "BOX ARRAY" ,
481+ [
482+ [PyBox ([3.5 , 3 , 9 , 9 ])],
483+ [PyBox ([8.5 , 8 , 9 , 9 ])],
484+ ],
485+ [
486+ [((9.0 , 9.0 ), (3.5 , 3.0 ))],
487+ [((9.0 , 9.0 ), (8.5 , 8.0 ))],
488+ ],
489+ ),
421490 (
422491 "PATH ARRAY" ,
423492 [
@@ -429,6 +498,17 @@ async def test_as_class(
429498 ((3.5 , 3.0 ), (6.0 , 6.0 ), (3.5 , 3.0 )),
430499 ],
431500 ),
501+ (
502+ "PATH ARRAY" ,
503+ [
504+ [PyPath ([(3.5 , 3 ), (9 , 9 ), (8 , 8 )])],
505+ [PyPath ([(3.5 , 3 ), (6 , 6 ), (3.5 , 3 )])],
506+ ],
507+ [
508+ [[(3.5 , 3.0 ), (9.0 , 9.0 ), (8.0 , 8.0 )]],
509+ [((3.5 , 3.0 ), (6.0 , 6.0 ), (3.5 , 3.0 ))],
510+ ],
511+ ),
432512 (
433513 "LINE ARRAY" ,
434514 [
@@ -440,6 +520,17 @@ async def test_as_class(
440520 (1.0 , - 2.0 , 3.0 ),
441521 ],
442522 ),
523+ (
524+ "LINE ARRAY" ,
525+ [
526+ [PyLine ([- 2 , 1 , 2 ])],
527+ [PyLine ([1 , - 2 , 3 ])],
528+ ],
529+ [
530+ [(- 2.0 , 1.0 , 2.0 )],
531+ [(1.0 , - 2.0 , 3.0 )],
532+ ],
533+ ),
443534 (
444535 "LSEG ARRAY" ,
445536 [
@@ -451,6 +542,17 @@ async def test_as_class(
451542 [(5.6 , 3.1 ), (4.0 , 5.0 )],
452543 ],
453544 ),
545+ (
546+ "LSEG ARRAY" ,
547+ [
548+ [PyLineSegment ({(1 , 2 ), (9 , 9 )})],
549+ [PyLineSegment ([(5.6 , 3.1 ), (4 , 5 )])],
550+ ],
551+ [
552+ [[(1.0 , 2.0 ), (9.0 , 9.0 )]],
553+ [[(5.6 , 3.1 ), (4.0 , 5.0 )]],
554+ ],
555+ ),
454556 (
455557 "CIRCLE ARRAY" ,
456558 [
@@ -462,6 +564,17 @@ async def test_as_class(
462564 ((5.0 , 1.8 ), 10.0 ),
463565 ],
464566 ),
567+ (
568+ "CIRCLE ARRAY" ,
569+ [
570+ [PyCircle ([1.7 , 2.8 , 3 ])],
571+ [PyCircle ([5 , 1.8 , 10 ])],
572+ ],
573+ [
574+ [((1.7 , 2.8 ), 3.0 )],
575+ [((5.0 , 1.8 ), 10.0 )],
576+ ],
577+ ),
465578 ),
466579)
467580async def test_deserialization_simple_into_python (
@@ -529,6 +642,7 @@ async def test_deserialization_composite_into_python(
529642 circle_ CIRCLE,
530643
531644 varchar_arr VARCHAR ARRAY,
645+ varchar_arr_mdim VARCHAR ARRAY,
532646 text_arr TEXT ARRAY,
533647 bool_arr BOOL ARRAY,
534648 int2_arr INT2 ARRAY,
@@ -569,9 +683,9 @@ class TestEnum(Enum):
569683 SAD = "sad"
570684 HAPPY = "happy"
571685
572- row_values = ", " .join ([f"${ index } " for index in range (1 , 40 )])
573- row_values += ", ROW($40 , $41 ), "
574- row_values += ", " .join ([f"${ index } " for index in range (42 , 49 )])
686+ row_values = ", " .join ([f"${ index } " for index in range (1 , 41 )])
687+ row_values += ", ROW($41 , $42 ), "
688+ row_values += ", " .join ([f"${ index } " for index in range (43 , 50 )])
575689
576690 await psql_pool .execute (
577691 querystring = f"INSERT INTO for_test VALUES (ROW({ row_values } ))" ,
@@ -609,6 +723,7 @@ class TestEnum(Enum):
609723 PyLineSegment (((1.7 , 2.8 ), (9 , 9 ))),
610724 PyCircle ([1.7 , 2.8 , 3 ]),
611725 ["Some String" , "Some String" ],
726+ [["Some String" ], ["Some String" ]],
612727 [PyText ("Some String" ), PyText ("Some String" )],
613728 [True , False ],
614729 [SmallInt (123 ), SmallInt (321 )],
@@ -706,6 +821,7 @@ class ValidateModelForCustomType(BaseModel):
706821 circle_ : Tuple [Tuple [float , float ], float ]
707822
708823 varchar_arr : List [str ]
824+ varchar_arr_mdim : List [List [str ]]
709825 text_arr : List [str ]
710826 bool_arr : List [bool ]
711827 int2_arr : List [int ]
@@ -867,3 +983,21 @@ def row_factory(db_result: Dict[str, Any]) -> List[str]:
867983 assert len (as_row_factory ) == expected_number_of_elements_in_result
868984
869985 assert isinstance (as_row_factory , list )
986+
987+
988+ async def test_incorrect_dimensions_array (
989+ psql_pool : ConnectionPool ,
990+ ) -> None :
991+ await psql_pool .execute ("DROP TABLE IF EXISTS test_marr" )
992+ await psql_pool .execute ("CREATE TABLE test_marr (var_array VARCHAR ARRAY)" )
993+
994+ with pytest .raises (expected_exception = PyToRustValueMappingError ):
995+ await psql_pool .execute (
996+ querystring = "INSERT INTO test_marr VALUES ($1)" ,
997+ parameters = [
998+ [
999+ ["Len" , "is" , "Three" ],
1000+ ["Len" , "is" , "Four" , "Wow" ],
1001+ ],
1002+ ],
1003+ )
0 commit comments