Skip to content

Commit 7060629

Browse files
authored
Merge pull request #346 from martin-neotech/4.0-protocol-handlers-io
4.0 protocol handlers io
2 parents bb028d4 + 14bbf74 commit 7060629

File tree

5 files changed

+189
-47
lines changed

5 files changed

+189
-47
lines changed

neo4j/io/__init__.py

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,55 @@ class Bolt:
9090
the handshake was carried out.
9191
"""
9292

93-
MAGIC_PREAMBLE = 0x6060B017
93+
MAGIC_PREAMBLE = b"\x60\x60\xB0\x17"
9494

9595
PROTOCOL_VERSION = None
9696

97+
@classmethod
98+
def get_handshake(cls):
99+
""" Return the supported Bolt versions as bytes.
100+
The length is 16 bytes as specified in the Bolt version negotiation.
101+
:return: bytes
102+
"""
103+
offered_versions = sorted(cls.protocol_handlers().keys(), reverse=True)[:4]
104+
return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00")
105+
106+
@classmethod
107+
def protocol_handlers(cls, protocol_version=None):
108+
""" Return a dictionary of available Bolt protocol handlers,
109+
keyed by version tuple. If an explicit protocol version is
110+
provided, the dictionary will contain either zero or one items,
111+
depending on whether that version is supported. If no protocol
112+
version is provided, all available versions will be returned.
113+
114+
:param protocol_version: tuple identifying a specific protocol
115+
version (e.g. (3, 5)) or None
116+
:return: dictionary of version tuple to handler class for all
117+
relevant and supported protocol versions
118+
:raise TypeError: if protocol version is not passed in a tuple
119+
"""
120+
121+
# Carry out subclass imports locally to avoid circular
122+
# dependency issues.
123+
from neo4j.io._bolt3 import Bolt3
124+
from neo4j.io._bolt4x0 import Bolt4x0
125+
126+
handlers = {
127+
Bolt3.PROTOCOL_VERSION: Bolt3,
128+
Bolt4x0.PROTOCOL_VERSION: Bolt4x0
129+
}
130+
131+
if protocol_version is None:
132+
return handlers
133+
134+
if not isinstance(protocol_version, tuple):
135+
raise TypeError("Protocol version must be specified as a tuple")
136+
137+
if protocol_version in handlers:
138+
return {protocol_version: handlers[protocol_version]}
139+
140+
return {}
141+
97142
@classmethod
98143
def ping(cls, address, *, timeout=None, **config):
99144
""" Attempt to establish a Bolt connection, returning the
@@ -757,12 +802,18 @@ def _handshake(s, resolved_address):
757802
"""
758803
local_port = s.getsockname()[1]
759804

760-
# Send details of the protocol versions supported
761-
supported_versions = [3, 0, 0, 0]
762-
handshake = [Bolt.MAGIC_PREAMBLE] + supported_versions
763-
log.debug("[#%04X] C: <MAGIC> 0x%08X", local_port, Bolt.MAGIC_PREAMBLE)
764-
log.debug("[#%04X] C: <HANDSHAKE> 0x%08X 0x%08X 0x%08X 0x%08X", local_port, *supported_versions)
765-
data = b"".join(struct_pack(">I", num) for num in handshake)
805+
# TODO: Optimize logging code
806+
handshake = Bolt.get_handshake()
807+
import struct
808+
handshake = struct.unpack(">16B", handshake)
809+
handshake = [handshake[i:i + 4] for i in range(0, len(handshake), 4)]
810+
811+
supported_versions = [("0x%02X%02X%02X%02X" % (vx[0], vx[1], vx[2], vx[3])) for vx in handshake]
812+
813+
log.debug("[#%04X] C: <MAGIC> 0x%08X", local_port, int.from_bytes(Bolt.MAGIC_PREAMBLE, byteorder="big"))
814+
log.debug("[#%04X] C: <HANDSHAKE> %s %s %s %s", local_port, *supported_versions)
815+
816+
data = Bolt.MAGIC_PREAMBLE + Bolt.get_handshake()
766817
s.sendall(data)
767818

768819
# Handle the handshake response
@@ -796,17 +847,7 @@ def _handshake(s, resolved_address):
796847
"(looks like HTTP)".format(resolved_address))
797848
agreed_version = data[-1], data[-2]
798849
log.debug("[#%04X] S: <HANDSHAKE> 0x%06X%02X", local_port, agreed_version[1], agreed_version[0])
799-
if agreed_version == (0, 0):
800-
log.debug("[#%04X] C: <CLOSE>", local_port)
801-
s.shutdown(SHUT_RDWR)
802-
s.close()
803-
elif agreed_version in ((3, 0),):
804-
return s, agreed_version
805-
else:
806-
log.debug("[#%04X] S: <CLOSE>", local_port)
807-
s.close()
808-
raise ProtocolError("Unknown Bolt protocol version: "
809-
"{}".format(agreed_version))
850+
return s, agreed_version
810851

811852

812853
def connect(address, *, timeout=None, config):

tests/stub/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,16 @@ def __init__(self, *servers):
116116
@fixture
117117
def script():
118118
return lambda *paths: path_join(dirname(__file__), "scripts", *paths)
119+
120+
121+
@fixture
122+
def driver_info():
123+
""" Base class for test cases that integrate with a server.
124+
"""
125+
return {
126+
"uri": "bolt://localhost:7687",
127+
"bolt_routing_uri": "bolt+routing://localhost:7687",
128+
"user": "test",
129+
"password": "test",
130+
"auth_token": ("test", "test")
131+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
!: BOLT 4
2+
!: AUTO HELLO
3+
!: AUTO GOODBYE
4+
!: AUTO RESET

tests/stub/test_directdriver.py

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,42 +19,63 @@
1919
# limitations under the License.
2020

2121

22+
import pytest
23+
2224
from neo4j.exceptions import ServiceUnavailable
2325

24-
from neo4j import GraphDatabase, BoltDriver
26+
from neo4j import (
27+
GraphDatabase,
28+
BoltDriver,
29+
)
30+
31+
from tests.stub.conftest import (
32+
StubCluster,
33+
)
2534

26-
from tests.stub.conftest import StubTestCase, StubCluster
35+
# python -m pytest tests/stub/test_directdriver.py
2736

2837

29-
class BoltDriverTestCase(StubTestCase):
38+
@pytest.mark.parametrize(
39+
"test_script",
40+
[
41+
"v3/empty.script",
42+
"v4x0/empty.script",
43+
]
44+
)
45+
def test_bolt_uri_constructs_bolt_driver(driver_info, test_script):
46+
# python -m pytest tests/stub/test_directdriver.py -s -v -k test_bolt_uri_constructs_bolt_driver
47+
with StubCluster(test_script):
48+
uri = "bolt://127.0.0.1:9001"
49+
with GraphDatabase.driver(uri, auth=driver_info["auth_token"]) as driver:
50+
assert isinstance(driver, BoltDriver)
3051

31-
def test_bolt_uri_constructs_bolt_driver(self):
32-
with StubCluster("v3/empty.script"):
33-
uri = "bolt://127.0.0.1:9001"
34-
with GraphDatabase.driver(uri, auth=self.auth_token) as driver:
35-
assert isinstance(driver, BoltDriver)
3652

37-
def test_direct_disconnect_on_run(self):
38-
with StubCluster("v3/disconnect_on_run.script"):
39-
uri = "bolt://127.0.0.1:9001"
40-
with GraphDatabase.driver(uri, auth=self.auth_token) as driver:
41-
with self.assertRaises(ServiceUnavailable):
42-
with driver.session() as session:
43-
session.run("RETURN $x", {"x": 1}).consume()
53+
def test_direct_disconnect_on_run(driver_info):
54+
# python -m pytest tests/stub/test_directdriver.py -s -v -k test_direct_disconnect_on_run
55+
with StubCluster("v3/disconnect_on_run.script"):
56+
uri = "bolt://127.0.0.1:9001"
57+
with GraphDatabase.driver(uri, auth=driver_info["auth_token"]) as driver:
58+
with pytest.raises(ServiceUnavailable):
59+
with driver.session() as session:
60+
session.run("RETURN $x", {"x": 1}).consume()
4461

45-
def test_direct_disconnect_on_pull_all(self):
46-
with StubCluster("v3/disconnect_on_pull_all.script"):
47-
uri = "bolt://127.0.0.1:9001"
48-
with GraphDatabase.driver(uri, auth=self.auth_token) as driver:
49-
with self.assertRaises(ServiceUnavailable):
50-
with driver.session() as session:
51-
session.run("RETURN $x", {"x": 1}).consume()
5262

53-
def test_direct_session_close_after_server_close(self):
54-
with StubCluster("v3/disconnect_after_init.script"):
55-
uri = "bolt://127.0.0.1:9001"
56-
with GraphDatabase.driver(uri, auth=self.auth_token, max_retry_time=0,
57-
acquire_timeout=3, user_agent="test") as driver:
63+
def test_direct_disconnect_on_pull_all(driver_info):
64+
# python -m pytest tests/stub/test_directdriver.py -s -v -k test_direct_disconnect_on_pull_all
65+
with StubCluster("v3/disconnect_on_pull_all.script"):
66+
uri = "bolt://127.0.0.1:9001"
67+
with GraphDatabase.driver(uri, auth=driver_info["auth_token"]) as driver:
68+
with pytest.raises(ServiceUnavailable):
5869
with driver.session() as session:
59-
with self.assertRaises(ServiceUnavailable):
60-
session.write_transaction(lambda tx: tx.run("CREATE (a:Item)"))
70+
session.run("RETURN $x", {"x": 1}).consume()
71+
72+
73+
def test_direct_session_close_after_server_close(driver_info):
74+
# python -m pytest tests/stub/test_directdriver.py -s -v -k test_direct_session_close_after_server_close
75+
with StubCluster("v3/disconnect_after_init.script"):
76+
uri = "bolt://127.0.0.1:9001"
77+
with GraphDatabase.driver(uri, auth=driver_info["auth_token"], max_retry_time=0,
78+
acquire_timeout=3, user_agent="test") as driver:
79+
with driver.session() as session:
80+
with pytest.raises(ServiceUnavailable):
81+
session.write_transaction(lambda tx: tx.run("CREATE (a:Item)"))

tests/unit/io/test_class_bolt.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) 2002-2020 "Neo4j,"
5+
# Neo4j Sweden AB [http://neo4j.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
22+
import pytest
23+
from neo4j.io import Bolt
24+
25+
# python -m pytest tests/unit/io/test_class_bolt.py -s -v
26+
27+
28+
def test_class_method_protocol_handlers():
29+
# python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers
30+
protocol_handlers = Bolt.protocol_handlers()
31+
assert len(protocol_handlers) == 2
32+
33+
34+
@pytest.mark.parametrize(
35+
"test_input, expected",
36+
[
37+
((0, 0), 0),
38+
((4, 0), 1),
39+
]
40+
)
41+
def test_class_method_protocol_handlers_with_protocol_version(test_input, expected):
42+
# python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers_with_protocol_version
43+
protocol_handlers = Bolt.protocol_handlers(protocol_version=test_input)
44+
assert len(protocol_handlers) == expected
45+
46+
47+
def test_class_method_protocol_handlers_with_invalid_protocol_version():
48+
# python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers_with_invalid_protocol_version
49+
with pytest.raises(TypeError):
50+
Bolt.protocol_handlers(protocol_version=2)
51+
52+
53+
def test_class_method_get_handshake():
54+
# python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_get_handshake
55+
handshake = Bolt.get_handshake()
56+
assert handshake == b"\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00"
57+
58+
59+
def test_magic_preamble():
60+
# python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_magic_preamble
61+
preamble = 0x6060B017
62+
preamble_bytes = preamble.to_bytes(4, byteorder="big")
63+
assert Bolt.MAGIC_PREAMBLE == preamble_bytes

0 commit comments

Comments
 (0)