Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SQLAlchemy>=0.7.5
SQLAlchemy>=1.0.10
81 changes: 46 additions & 35 deletions sqlalchemy_tree/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,27 +595,30 @@ def _inter_tree_move_and_close_gap(
depth = getattr(node, options.depth_field.name)
gap_size = right - left + 1

# Note: For MySQL, the order of the values in the SET clause matters.
# http://docs.sqlalchemy.org/en/latest/core/tutorial.html#updates-order-parameters
# http://dev.mysql.com/doc/refman/5.7/en/update.html
connection.execute(
options.table.update()
.values({
options.parent_id_field: sqlalchemy.case(
options.table.update(preserve_parameter_order=True)
.values([
(options.parent_id_field, sqlalchemy.case(
[(options.pk_field == getattr(node, options.pk_field.name), parent_id)],
else_=options.parent_id_field),
options.tree_id_field: sqlalchemy.case(
else_=options.parent_id_field)),
(options.tree_id_field, sqlalchemy.case(
[((options.left_field >= left) & (options.left_field <= right), new_tree_id)],
else_=options.tree_id_field),
options.left_field: sqlalchemy.case(
else_=options.tree_id_field)),
(options.depth_field, sqlalchemy.case(
[((options.left_field >= left) & (options.left_field <= right), options.depth_field + depth_change)],
else_=options.depth_field)),
(options.left_field, sqlalchemy.case(
[((options.left_field >= left) & (options.left_field <= right), options.left_field + left_right_change),
((options.left_field > right), options.left_field - gap_size)],
else_=options.left_field),
options.right_field: sqlalchemy.case(
else_=options.left_field)),
(options.right_field, sqlalchemy.case(
[((options.right_field >= left) & (options.right_field <= right), options.right_field + left_right_change),
((options.right_field > right), options.right_field - gap_size)],
else_=options.right_field),
options.depth_field: sqlalchemy.case(
[((options.left_field >= left) & (options.left_field <= right), options.depth_field + depth_change)],
else_=options.depth_field),
})
else_=options.right_field))
])
.where(options.tree_id_field == tree_id))
for obj in session_objs:
obj_tree_id = getattr(obj, options.tree_id_field.name)
Expand Down Expand Up @@ -925,24 +928,27 @@ def _move_child_within_tree(
if left_right_change > 0:
gap_size = -gap_size

# Note: For MySQL, the order of the values in the SET clause matters.
# http://docs.sqlalchemy.org/en/latest/core/tutorial.html#updates-order-parameters
# http://dev.mysql.com/doc/refman/5.7/en/update.html
connection.execute(
options.table.update()
.values({
options.parent_id_field: sqlalchemy.case(
options.table.update(preserve_parameter_order=True)
.values([
(options.parent_id_field, sqlalchemy.case(
[(options.pk_field == getattr(node, options.pk_field.name), parent_id)],
else_=options.parent_id_field),
options.left_field: sqlalchemy.case(
else_=options.parent_id_field)),
(options.depth_field, sqlalchemy.case(
[((options.left_field >= left) & (options.left_field <= right), options.depth_field + depth_change)],
else_=options.depth_field)),
(options.left_field, sqlalchemy.case(
[((options.left_field >= left) & (options.left_field <= right), options.left_field + left_right_change),
((options.left_field >= left_boundary) & (options.left_field <= right_boundary), options.left_field + gap_size)],
else_=options.left_field),
options.right_field: sqlalchemy.case(
else_=options.left_field)),
(options.right_field, sqlalchemy.case(
[((options.right_field >= left) & (options.right_field <= right), options.right_field + left_right_change),
((options.right_field >= left_boundary) & (options.right_field <= right_boundary), options.right_field + gap_size)],
else_=options.right_field),
options.depth_field: sqlalchemy.case(
[((options.left_field >= left) & (options.left_field <= right), options.depth_field + depth_change)],
else_=options.depth_field),
})
else_=options.right_field)),
])
.where(options.tree_id_field == tree_id))
for obj in session_objs:
obj_tree_id = getattr(obj, options.tree_id_field.name)
Expand Down Expand Up @@ -1002,13 +1008,17 @@ def before_flush(self, session, flush_context, instances):
if not isinstance(node, self._node_class):
continue

parent_field_changed = sqlalchemy.orm.attributes.get_history(
node, options.parent_field_name).has_changes()
parent_id_field_changed = sqlalchemy.orm.attributes.get_history(
node, options.parent_id_field.name).has_changes()

if hasattr(node, options.delayed_op_attr):
setattr(node, options.delayed_op_attr,
(getattr(node, options.delayed_op_attr), session_objs))

elif (node in session.new or
sqlalchemy.orm.attributes.get_history(
node, options.parent_field_name).has_changes()):
elif (node in session.new or parent_field_changed or
parent_id_field_changed):

if (hasattr(options, 'order_with_respect_to') and
len(options.order_with_respect_to)):
Expand All @@ -1018,12 +1028,13 @@ def before_flush(self, session, flush_context, instances):
position = options.class_manager.POSITION_LAST_CHILD
target = getattr(node, options.parent_field_name)
target_id = getattr(node, options.parent_id_field.name)
if not target or target.id != target_id:
# If the parent relationship is not set, or changed,
# try to get it from the parent id column.
if target_id is None:
target = None
else:

# Query for the parent node based on the id IF:
# - The object is new and we only have the target parent's id.
# - The parent id has changed, but not the parent relationship.
if (session.new and not target) or (
parent_id_field_changed and not parent_field_changed):
if target_id:
target = session.query(options.node_class).get(target_id)

setattr(
Expand Down
69 changes: 69 additions & 0 deletions sqlalchemy_tree/tests/Update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""
sqlalchemy_tree.tests.Update
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:copyright: (C) 2012-2014 the SQLAlchemy-ORM-Tree authors and contributors
<see AUTHORS file>.
:license: BSD, see LICENSE for more details.
"""

from __future__ import absolute_import, division, print_function, \
with_statement, unicode_literals

from .helper import unittest, Named, db, get_tree_details
from .Named import TreeTestMixin


class UpdateTestCase(TreeTestMixin, unittest.TestCase):

name_pattern = [
(u"root1", {'id': 1, 'left': 1, 'right': 8, 'depth': 0}, [
(u"child11", {'id': 1, 'left': 2, 'right': 3, 'depth': 1}, []),
(u"child12", {'id': 1, 'left': 4, 'right': 7, 'depth': 1}, [
(u"child13", {'id': 1, 'left': 5, 'right': 6, 'depth': 2}, []),
]),
])
]

def _test_move_subtree_to_root(self, arg):
result = [
(u"root1", {'id': 1, 'left': 1, 'right': 4, 'depth': 0}, [
(u"child11", {'id': 1, 'left': 2, 'right': 3, 'depth': 1}, []),
]),
(u"child12", {'id': 2, 'left': 1, 'right': 4, 'depth': 0}, [
(u"child13", {'id': 2, 'left': 2, 'right': 3, 'depth': 1}, []),
])
]

node = db.session.query(Named).filter_by(name='child12').one()
node.parent_id = None
db.session.commit()
self.assertEqual(get_tree_details(), result)

def test_move_subtree_to_root_by_id(self):
self._test_move_subtree_to_root('parent_id')

def test_move_subtree_to_root_by_relationship(self):
self._test_move_subtree_to_root('parent')

def test_move_subtree_to_parent(self):
result = [
(u"root1", {'id': 1, 'left': 1, 'right': 8, 'depth': 0}, [
(u"child11", {'id': 1, 'left': 2, 'right': 3, 'depth': 1}, []),
(u"child12", {'id': 1, 'left': 4, 'right': 5, 'depth': 1}, []),
(u"child13", {'id': 1, 'left': 6, 'right': 7, 'depth': 1}, []),
])
]

node = db.session.query(Named).filter_by(name='child13').one()
node.parent_id = db.session.query(Named).filter_by(name='root1').one().id
db.session.commit()
self.assertEqual(get_tree_details(), result)



def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(UpdateTestCase))
return suite