11import operator
2+ import collections
23
34import pytest
45
5-
6+ import pandas as pd
7+ import pandas .util .testing as tm
68from pandas .compat import PY2 , PY36
79from pandas .tests .extension import base
810
@@ -59,27 +61,76 @@ def data_for_grouping():
5961 ])
6062
6163
62- class TestDtype (base .BaseDtypeTests ):
64+ class BaseJSON (object ):
65+ # NumPy doesn't handle an array of equal-length UserDicts.
66+ # The default assert_series_equal eventually does a
67+ # Series.values, which raises. We work around it by
68+ # converting the UserDicts to dicts.
69+ def assert_series_equal (self , left , right , ** kwargs ):
70+ if left .dtype .name == 'json' :
71+ assert left .dtype == right .dtype
72+ left = pd .Series (JSONArray (left .values .astype (object )),
73+ index = left .index , name = left .name )
74+ right = pd .Series (JSONArray (right .values .astype (object )),
75+ index = right .index , name = right .name )
76+ tm .assert_series_equal (left , right , ** kwargs )
77+
78+ def assert_frame_equal (self , left , right , * args , ** kwargs ):
79+ tm .assert_index_equal (
80+ left .columns , right .columns ,
81+ exact = kwargs .get ('check_column_type' , 'equiv' ),
82+ check_names = kwargs .get ('check_names' , True ),
83+ check_exact = kwargs .get ('check_exact' , False ),
84+ check_categorical = kwargs .get ('check_categorical' , True ),
85+ obj = '{obj}.columns' .format (obj = kwargs .get ('obj' , 'DataFrame' )))
86+
87+ jsons = (left .dtypes == 'json' ).index
88+
89+ for col in jsons :
90+ self .assert_series_equal (left [col ], right [col ],
91+ * args , ** kwargs )
92+
93+ left = left .drop (columns = jsons )
94+ right = right .drop (columns = jsons )
95+ tm .assert_frame_equal (left , right , * args , ** kwargs )
96+
97+
98+ class TestDtype (BaseJSON , base .BaseDtypeTests ):
6399 pass
64100
65101
66- class TestInterface (base .BaseInterfaceTests ):
67- pass
102+ class TestInterface (BaseJSON , base .BaseInterfaceTests ):
103+ def test_custom_asserts (self ):
104+ # This would always trigger the KeyError from trying to put
105+ # an array of equal-length UserDicts inside an ndarray.
106+ data = JSONArray ([collections .UserDict ({'a' : 1 }),
107+ collections .UserDict ({'b' : 2 }),
108+ collections .UserDict ({'c' : 3 })])
109+ a = pd .Series (data )
110+ self .assert_series_equal (a , a )
111+ self .assert_frame_equal (a .to_frame (), a .to_frame ())
112+
113+ b = pd .Series (data .take ([0 , 0 , 1 ]))
114+ with pytest .raises (AssertionError ):
115+ self .assert_series_equal (a , b )
116+
117+ with pytest .raises (AssertionError ):
118+ self .assert_frame_equal (a .to_frame (), b .to_frame ())
68119
69120
70- class TestConstructors (base .BaseConstructorsTests ):
121+ class TestConstructors (BaseJSON , base .BaseConstructorsTests ):
71122 pass
72123
73124
74- class TestReshaping (base .BaseReshapingTests ):
125+ class TestReshaping (BaseJSON , base .BaseReshapingTests ):
75126 pass
76127
77128
78- class TestGetitem (base .BaseGetitemTests ):
129+ class TestGetitem (BaseJSON , base .BaseGetitemTests ):
79130 pass
80131
81132
82- class TestMissing (base .BaseMissingTests ):
133+ class TestMissing (BaseJSON , base .BaseMissingTests ):
83134 @pytest .mark .xfail (reason = "Setting a dict as a scalar" )
84135 def test_fillna_series (self ):
85136 """We treat dictionaries as a mapping in fillna, not a scalar."""
@@ -94,7 +145,7 @@ def test_fillna_frame(self):
94145 reason = "Dictionary order unstable" )
95146
96147
97- class TestMethods (base .BaseMethodsTests ):
148+ class TestMethods (BaseJSON , base .BaseMethodsTests ):
98149 @unhashable
99150 def test_value_counts (self , all_data , dropna ):
100151 pass
@@ -126,7 +177,7 @@ def test_sort_values_missing(self, data_missing_for_sorting, ascending):
126177 data_missing_for_sorting , ascending )
127178
128179
129- class TestCasting (base .BaseCastingTests ):
180+ class TestCasting (BaseJSON , base .BaseCastingTests ):
130181 @pytest .mark .xfail
131182 def test_astype_str (self ):
132183 """This currently fails in NumPy on np.array(self, dtype=str) with
@@ -139,7 +190,7 @@ def test_astype_str(self):
139190# internals has trouble setting sequences of values into scalar positions.
140191
141192
142- class TestGroupby (base .BaseGroupbyTests ):
193+ class TestGroupby (BaseJSON , base .BaseGroupbyTests ):
143194
144195 @unhashable
145196 def test_groupby_extension_transform (self ):
0 commit comments