diff --git a/patsy/eval.py b/patsy/eval.py
index d4ed83f..5a0f07f 100644
--- a/patsy/eval.py
+++ b/patsy/eval.py
@@ -24,6 +24,7 @@
 from patsy.tokens import (pretty_untokenize, normalize_token_spacing,
                           python_tokenize)
 from patsy.compat import call_and_wrap_exc
+from patsy.version import __version__
 
 def _all_future_flags():
     flags = 0
@@ -565,7 +566,64 @@ def eval(self, memorize_state, data):
                           memorize_state,
                           data)
 
-    __getstate__ = no_pickling
+    def __getstate__(self):
+        """Return the picklable state: the factor's code and origin, tagged
+        with the patsy version that wrote the pickle."""
+        return {
+            'version': __version__,
+            'code': self.code,
+            'origin': self.origin
+        }
+
+    def __setstate__(self, state):
+        """Restore this factor from a dict produced by __getstate__.
+
+        A missing required field always raises KeyError. A missing optional
+        field is tolerated only when the pickle was written by a newer patsy
+        release: a FutureWarning is issued and the attribute is set to None.
+        """
+        # Local imports: only the unpickling path needs them.
+        import re
+        import warnings
+
+        def release_key(version):
+            # Compare releases numerically, ignoring any '+local' suffix;
+            # plain string comparison would rank '0.10.0' below '0.9.0'.
+            return tuple(int(number) for number in
+                         re.findall(r'\d+', version.split('+')[0]))
+
+        # A tuple of pairs keeps the restore order deterministic.
+        expected_fields = (
+            ('code', 'REQUIRED'),
+            ('origin', 'OPTIONAL'),
+        )
+
+        pickling_version = state.get('version')
+        pickled_by_newer_version = (
+            pickling_version is not None and
+            release_key(pickling_version) > release_key(__version__))
+
+        for field, requirement in expected_fields:
+            if field in state:
+                setattr(self, field, state[field])
+            elif requirement == 'OPTIONAL' and pickled_by_newer_version:
+                # Degrade gracefully: warn instead of aborting the unpickle.
+                msg = ("This EvalFactor was pickled with patsy version %s, "
+                       "and cannot be unpickled with full fidelity by "
+                       "version %s. In particular, you have access to "
+                       "`code` but not to `origin`"
+                       % (pickling_version, __version__))
+                warnings.warn(msg, FutureWarning)
+                setattr(self, field, None)
+            elif pickled_by_newer_version:
+                msg = ("This EvalFactor was pickled with patsy version %s, "
+                       "and cannot be unpickled with version %s"
+                       % (pickling_version, __version__))
+                raise KeyError(msg)
+            else:
+                msg = "Unable to unpickle EvalFactor field %s." % field
+                raise KeyError(msg)
+
 
 def test_EvalFactor_basics():
     e = EvalFactor("a+b")
@@ -577,8 +635,6 @@ def test_EvalFactor_basics():
     assert e.origin is None
     assert e2.origin == "asdf"
 
-    assert_no_pickling(e)
-
 def test_EvalFactor_memorize_passes_needed():
     from patsy.state import stateful_transform
     foo = stateful_transform(lambda: "FOO-OBJ")
diff --git a/patsy/test_highlevel.py b/patsy/test_highlevel.py
index 9b7be37..2dc45fd 100644
--- a/patsy/test_highlevel.py
+++ b/patsy/test_highlevel.py
@@ -5,6 +5,7 @@
 # Exhaustive end-to-end tests of the top-level API.
 
 import sys
+from six.moves import cPickle as pickle
 import __future__
 import six
 import numpy as np
@@ -758,3 +759,28 @@ def test_C_and_pandas_categorical():
                            [[1, 0],
                             [1, 1],
                             [1, 0]])
+
+def test_pickle_builder_roundtrips():
+    import numpy as np
+    # TODO Add center(x) and categorical interaction, and call to np.log to patsy formula.
+    design_matrix = dmatrix("x + a", {"x": [1, 2, 3],
+                                      "a": ["a1", "a2", "a3"]})
+    # TODO Remove builder, pass design_info to dmatrix() instead.
+    builder = design_matrix.design_info.builder
+    # Presumably dropped so the builders below cannot lean on `np` from this
+    # namespace -- TODO confirm once the formula actually calls np.log.
+    del np
+
+    new_data = {"x": [10, 20, 30],
+                "a": ["a3", "a1", "a2"]}
+    m1 = dmatrix(builder, new_data)
+
+    builder2 = pickle.loads(pickle.dumps(builder))
+    m2 = dmatrix(builder2, new_data)
+
+    # `np` is a local name here and was deleted above, so referencing it
+    # would raise UnboundLocalError; compare through a separate import.
+    from numpy import allclose
+    assert allclose(m1, m2)
+
+
diff --git a/patsy/test_pickling.py b/patsy/test_pickling.py
new file mode 100644
index 0000000..5a3943c
--- /dev/null
+++ b/patsy/test_pickling.py
@@ -0,0 +1,26 @@
+from six.moves import cPickle as pickle
+
+from patsy.eval import EvalFactor
+from patsy.version import __version__
+
+
+# Each entry pairs an expression that builds an object with the pickles that
+# past patsy versions produced for it, keyed by the version that wrote them.
+objects_to_test = [
+    ("EvalFactor('a+b', 'mars')", {
+        "0.4.1+dev": b"ccopy_reg\n_reconstructor\np1\n(cpatsy.eval\nEvalFactor\np2\nc__builtin__\nobject\np3\nNtRp4\n(dp5\nS\'code\'\np6\nS\'a + b\'\np7\nsS\'origin\'\np8\nS\'mars\'\np9\nsS\'version\'\np10\nS\'0.4.1+dev\'\np11\nsb."
+        })
+    ]
+
+def test_pickling_roundtrips():
+    for obj_code, pickled_history in objects_to_test:
+        # obj_code comes only from the trusted literals above.
+        obj = eval(obj_code)
+        # The object must survive a round trip under every protocol.
+        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+            assert obj == pickle.loads(pickle.dumps(obj, protocol))
+        # Historical pickles must still load to an equal object. Comparing
+        # raw bytes instead would break on every version bump, because the
+        # pickled state embeds the patsy version.
+        for version, pickled in pickled_history.items():
+            assert obj == pickle.loads(pickled), version