Skip to content

Commit d36334e

Browse files
committed
Removed AddressList and added custom resolution to Address class
1 parent ac049dd commit d36334e

File tree

3 files changed

+40
-83
lines changed

3 files changed

+40
-83
lines changed

neo4j/addressing.py

Lines changed: 36 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,19 @@ def host(self):
9797
def port(self):
9898
return self[1]
9999

100-
def resolve(self, family=0):
101-
# TODO: custom resolver argument
100+
@classmethod
101+
def _dns_resolve(cls, address, family=0):
102+
""" Regular DNS resolver. Takes an address object and optional
103+
address family for filtering.
104+
105+
:param address:
106+
:param family:
107+
:return:
108+
"""
102109
try:
103-
info = getaddrinfo(self.host, self.port, family, SOCK_STREAM)
110+
info = getaddrinfo(address.host, address.port, family, SOCK_STREAM)
104111
except OSError:
105-
raise ValueError("Cannot resolve address {}".format(self))
112+
raise ValueError("Cannot resolve address {}".format(address))
106113
else:
107114
resolved = []
108115
for fam, _, _, _, addr in info:
@@ -114,6 +121,31 @@ def resolve(self, family=0):
114121
resolved.append(Address(addr))
115122
return resolved
116123

124+
def resolve(self, family=0, resolver=None):
125+
""" Carry out domain name resolution on this Address object.
126+
127+
If a resolver function is supplied, and is callable, this is
128+
called first, with this object as its argument. This may yield
129+
multiple output addresses, which are chained into a subsequent
130+
regular DNS resolution call. If no resolver function is passed,
131+
the DNS resolution is carried out on the original Address
132+
object.
133+
134+
This function returns a list of resolved Address objects.
135+
136+
:param family: optional address family to filter resolved
137+
addresses by (e.g. AF_INET6)
138+
:param resolver: optional customer resolver function to be
139+
called before regular DNS resolution
140+
"""
141+
resolved = []
142+
if resolver:
143+
for address in map(Address, resolver(self)):
144+
resolved.extend(self._dns_resolve(address, family))
145+
else:
146+
resolved.extend(self._dns_resolve(self, family))
147+
return resolved
148+
117149
@property
118150
def port_number(self):
119151
try:
@@ -141,75 +173,3 @@ class IPv6Address(Address):
141173

142174
def __str__(self):
143175
return "[{}]:{}".format(*self)
144-
145-
146-
# TODO: deprecate
147-
class AddressList(list):
148-
""" A list of socket addresses, each as a tuple of the format expected by
149-
the built-in `socket.connect` method.
150-
"""
151-
152-
@classmethod
153-
def parse(cls, s, default_host=None, default_port=None):
154-
""" Parse a string containing one or more socket addresses, each
155-
separated by whitespace.
156-
"""
157-
if isinstance(s, str):
158-
return cls([Address.parse(a, default_host, default_port)
159-
for a in s.split()])
160-
else:
161-
raise TypeError("AddressList.parse requires a string argument")
162-
163-
def __init__(self, iterable=None):
164-
super().__init__(map(Address, iterable or ()))
165-
166-
def __str__(self):
167-
return " ".join(str(Address(_)) for _ in self)
168-
169-
def __repr__(self):
170-
return "{}({!r})".format(self.__class__.__name__, list(self))
171-
172-
def dns_resolve(self, family=0):
173-
""" Resolve all addresses into one or more resolved address tuples
174-
using DNS. Each host name will resolve into one or more IP addresses,
175-
limited by the given address `family` (if any). Each port value
176-
(either integer or string) will resolve into an integer port value
177-
(e.g. 'http' will resolve to 80).
178-
179-
>>> a = AddressList([("localhost", "http")])
180-
>>> a.dns_resolve()
181-
>>> a
182-
AddressList([('::1', 80, 0, 0), ('127.0.0.1', 80)])
183-
184-
"""
185-
resolved = []
186-
for address in iter(self):
187-
host = address[0]
188-
port = address[1]
189-
try:
190-
info = getaddrinfo(host, port, family, SOCK_STREAM)
191-
except OSError:
192-
raise ValueError("Cannot resolve address {!r}".format(address))
193-
else:
194-
for _, _, _, _, addr in info:
195-
if len(address) == 4 and address[3] != 0:
196-
# skip any IPv6 addresses with a non-zero scope id
197-
# as these appear to cause problems on some platforms
198-
continue
199-
if addr not in resolved:
200-
resolved.append(addr)
201-
self[:] = resolved
202-
203-
def custom_resolve(self, resolver):
204-
""" Perform custom resolution on the contained addresses using a
205-
resolver function.
206-
207-
:return:
208-
"""
209-
if not callable(resolver):
210-
return
211-
new_addresses = []
212-
for address in iter(self):
213-
for new_address in resolver(address):
214-
new_addresses.append(new_address)
215-
self[:] = new_addresses

neo4j/io/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
Condition,
6767
)
6868

69-
from neo4j.addressing import AddressList
69+
from neo4j.addressing import Address
7070
from neo4j.conf import PoolConfig
7171
from neo4j.errors import BoltRoutingError, Neo4jAvailabilityError
7272
from neo4j.exceptions import (
@@ -798,10 +798,8 @@ def connect(address, *, timeout=None, config):
798798
# Catches refused connections see:
799799
# https://docs.python.org/2/library/errno.html
800800
log.debug("[#0000] C: <RESOLVE> %s", address)
801-
address_list = AddressList([address])
802-
address_list.custom_resolve(config.get("resolver"))
803-
address_list.dns_resolve()
804-
for resolved_address in address_list:
801+
custom_resolver = config.get("resolver")
802+
for resolved_address in Address(address).resolve(resolver=custom_resolver):
805803
s = None
806804
try:
807805
host = address[0]

tests/integration/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from pytest import fixture, skip
2828

2929
from neo4j import GraphDatabase
30-
from neo4j.addressing import AddressList
3130
from neo4j.exceptions import ServiceUnavailable
3231
from neo4j.io import Bolt
3332

@@ -145,7 +144,7 @@ def __init__(self, host, core_port_range, replica_port_range):
145144

146145
@property
147146
def addresses(self):
148-
return AddressList(machine.address for machine in self.cores())
147+
return [machine.address for machine in self.cores()]
149148

150149
@property
151150
def auth(self):

0 commit comments

Comments
 (0)