From e7e7bf54af24155c9eea0560f1012398e9394268 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Elsd=C3=B6rfer?= Date: Fri, 30 Sep 2016 14:14:46 +0200 Subject: [PATCH] Fix subtree movement on MySQL. The problem is that if a SET has multiple values in MySQL, column changes in the first clause affect the subsequent CASE statements. In other words, if we first change left_field and right_field, and then subsequently want to set tree_depth via a CASE(), the evaluation of that CASE will be affected by the left/right changes. As part of this change, I also fixed the ORM change detection for moves, which can now understand both the relationship and the id field changing. --- requirements.txt | 2 +- sqlalchemy_tree/orm.py | 81 +++++++++++++++++++-------------- sqlalchemy_tree/tests/Update.py | 69 ++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 36 deletions(-) create mode 100644 sqlalchemy_tree/tests/Update.py diff --git a/requirements.txt b/requirements.txt index 0676f00..ab4978c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -SQLAlchemy>=0.7.5 +SQLAlchemy>=1.0.10 \ No newline at end of file diff --git a/sqlalchemy_tree/orm.py b/sqlalchemy_tree/orm.py index 2ca21ff..06bd427 100644 --- a/sqlalchemy_tree/orm.py +++ b/sqlalchemy_tree/orm.py @@ -595,27 +595,30 @@ def _inter_tree_move_and_close_gap( depth = getattr(node, options.depth_field.name) gap_size = right - left + 1 + # Note: For MySQL, the order of the values in the SET clause matters. + # http://docs.sqlalchemy.org/en/latest/core/tutorial.html#updates-order-parameters + # http://dev.mysql.com/doc/refman/5.7/en/update.html connection.execute( - options.table.update() - .values({ - options.parent_id_field: sqlalchemy.case( + options.table.update(preserve_parameter_order=True) + .values([ + (options.parent_id_field, sqlalchemy.case( [(options.pk_field == getattr(node, options.pk_field.name), parent_id)], - else_=options.parent_id_field), - options.tree_id_field: sqlalchemy.case( + else_=options.parent_id_field)), + (options.tree_id_field, sqlalchemy.case( [((options.left_field >= left) & (options.left_field <= right), new_tree_id)], - else_=options.tree_id_field), - options.left_field: sqlalchemy.case( + else_=options.tree_id_field)), + (options.depth_field, sqlalchemy.case( + [((options.left_field >= left) & (options.left_field <= right), options.depth_field + depth_change)], + else_=options.depth_field)), + (options.left_field, sqlalchemy.case( [((options.left_field >= left) & (options.left_field <= right), options.left_field + left_right_change), ((options.left_field > right), options.left_field - gap_size)], - else_=options.left_field), - options.right_field: sqlalchemy.case( + else_=options.left_field)), + (options.right_field, sqlalchemy.case( [((options.right_field >= left) & (options.right_field <= right), options.right_field + left_right_change), ((options.right_field > right), options.right_field - gap_size)], - else_=options.right_field), - options.depth_field: sqlalchemy.case( - [((options.left_field >= left) & (options.left_field <= right), options.depth_field + depth_change)], - else_=options.depth_field), - }) + else_=options.right_field)) + ]) .where(options.tree_id_field == tree_id)) for obj in session_objs: obj_tree_id = getattr(obj, options.tree_id_field.name) @@ -925,24 +928,27 @@ def _move_child_within_tree( if left_right_change > 0: gap_size = -gap_size + # Note: For MySQL, the order of the values in the SET clause matters. + # http://docs.sqlalchemy.org/en/latest/core/tutorial.html#updates-order-parameters + # http://dev.mysql.com/doc/refman/5.7/en/update.html connection.execute( - options.table.update() - .values({ - options.parent_id_field: sqlalchemy.case( + options.table.update(preserve_parameter_order=True) + .values([ + (options.parent_id_field, sqlalchemy.case( [(options.pk_field == getattr(node, options.pk_field.name), parent_id)], - else_=options.parent_id_field), - options.left_field: sqlalchemy.case( + else_=options.parent_id_field)), + (options.depth_field, sqlalchemy.case( + [((options.left_field >= left) & (options.left_field <= right), options.depth_field + depth_change)], + else_=options.depth_field)), + (options.left_field, sqlalchemy.case( [((options.left_field >= left) & (options.left_field <= right), options.left_field + left_right_change), ((options.left_field >= left_boundary) & (options.left_field <= right_boundary), options.left_field + gap_size)], - else_=options.left_field), - options.right_field: sqlalchemy.case( + else_=options.left_field)), + (options.right_field, sqlalchemy.case( [((options.right_field >= left) & (options.right_field <= right), options.right_field + left_right_change), ((options.right_field >= left_boundary) & (options.right_field <= right_boundary), options.right_field + gap_size)], - else_=options.right_field), - options.depth_field: sqlalchemy.case( - [((options.left_field >= left) & (options.left_field <= right), options.depth_field + depth_change)], - else_=options.depth_field), - }) + else_=options.right_field)), + ]) .where(options.tree_id_field == tree_id)) for obj in session_objs: obj_tree_id = getattr(obj, options.tree_id_field.name) @@ -1002,13 +1008,17 @@ def before_flush(self, session, flush_context, instances): if not isinstance(node, self._node_class): continue + parent_field_changed = sqlalchemy.orm.attributes.get_history( + node, options.parent_field_name).has_changes() + parent_id_field_changed = sqlalchemy.orm.attributes.get_history( + node, options.parent_id_field.name).has_changes() + if hasattr(node, options.delayed_op_attr): setattr(node, options.delayed_op_attr, (getattr(node, options.delayed_op_attr), session_objs)) - elif (node in session.new or - sqlalchemy.orm.attributes.get_history( - node, options.parent_field_name).has_changes()): + elif (node in session.new or parent_field_changed or + parent_id_field_changed): if (hasattr(options, 'order_with_respect_to') and len(options.order_with_respect_to)): @@ -1018,12 +1028,13 @@ def before_flush(self, session, flush_context, instances): position = options.class_manager.POSITION_LAST_CHILD target = getattr(node, options.parent_field_name) target_id = getattr(node, options.parent_id_field.name) - if not target or target.id != target_id: - # If the parent relationship is not set, or changed, - # try to get it from the parent id column. - if target_id is None: - target = None - else: + + # Query for the parent node based on the id IF: + # - The object is new and we only have the target parent's id. + # - The parent id has changed, but not the parent relationship. + if (session.new and not target) or ( + parent_id_field_changed and not parent_field_changed): + if target_id: target = session.query(options.node_class).get(target_id) setattr( diff --git a/sqlalchemy_tree/tests/Update.py b/sqlalchemy_tree/tests/Update.py new file mode 100644 index 0000000..6abec8f --- /dev/null +++ b/sqlalchemy_tree/tests/Update.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +""" + sqlalchemy_tree.tests.Update + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + :copyright: (C) 2012-2014 the SQLAlchemy-ORM-Tree authors and contributors + . + :license: BSD, see LICENSE for more details. +""" + +from __future__ import absolute_import, division, print_function, \ + with_statement, unicode_literals + +from .helper import unittest, Named, db, get_tree_details +from .Named import TreeTestMixin + + +class UpdateTestCase(TreeTestMixin, unittest.TestCase): + + name_pattern = [ + (u"root1", {'id': 1, 'left': 1, 'right': 8, 'depth': 0}, [ + (u"child11", {'id': 1, 'left': 2, 'right': 3, 'depth': 1}, []), + (u"child12", {'id': 1, 'left': 4, 'right': 7, 'depth': 1}, [ + (u"child13", {'id': 1, 'left': 5, 'right': 6, 'depth': 2}, []), + ]), + ]) + ] + + def _test_move_subtree_to_root(self, arg): + result = [ + (u"root1", {'id': 1, 'left': 1, 'right': 4, 'depth': 0}, [ + (u"child11", {'id': 1, 'left': 2, 'right': 3, 'depth': 1}, []), + ]), + (u"child12", {'id': 2, 'left': 1, 'right': 4, 'depth': 0}, [ + (u"child13", {'id': 2, 'left': 2, 'right': 3, 'depth': 1}, []), + ]) + ] + + node = db.session.query(Named).filter_by(name='child12').one() + node.parent_id = None + db.session.commit() + self.assertEqual(get_tree_details(), result) + + def test_move_subtree_to_root_by_id(self): + self._test_move_subtree_to_root('parent_id') + + def test_move_subtree_to_root_by_relationship(self): + self._test_move_subtree_to_root('parent') + + def test_move_subtree_to_parent(self): + result = [ + (u"root1", {'id': 1, 'left': 1, 'right': 8, 'depth': 0}, [ + (u"child11", {'id': 1, 'left': 2, 'right': 3, 'depth': 1}, []), + (u"child12", {'id': 1, 'left': 4, 'right': 5, 'depth': 1}, []), + (u"child13", {'id': 1, 'left': 6, 'right': 7, 'depth': 1}, []), + ]) + ] + + node = db.session.query(Named).filter_by(name='child13').one() + node.parent_id = db.session.query(Named).filter_by(name='root1').one().id + db.session.commit() + self.assertEqual(get_tree_details(), result) + + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(UpdateTestCase)) + return suite