|
1 | 1 | from functools import partial |
2 | 2 |
|
3 | 3 | from django.db.models.query import QuerySet |
| 4 | +from graphene import NonNull |
4 | 5 |
|
5 | 6 | from promise import Promise |
6 | 7 |
|
@@ -45,17 +46,31 @@ def type(self): |
45 | 46 | from .types import DjangoObjectType |
46 | 47 |
|
47 | 48 | _type = super(ConnectionField, self).type |
| 49 | + non_null = False |
| 50 | + if isinstance(_type, NonNull): |
| 51 | + _type = _type.of_type |
| 52 | + non_null = True |
48 | 53 | assert issubclass( |
49 | 54 | _type, DjangoObjectType |
50 | 55 | ), "DjangoConnectionField only accepts DjangoObjectType types" |
51 | 56 | assert _type._meta.connection, "The type {} doesn't have a connection".format( |
52 | 57 | _type.__name__ |
53 | 58 | ) |
54 | | - return _type._meta.connection |
| 59 | + connection_type = _type._meta.connection |
| 60 | + if non_null: |
| 61 | + return NonNull(connection_type) |
| 62 | + return connection_type |
| 63 | + |
| 64 | + @property |
| 65 | + def connection_type(self): |
| 66 | + type = self.type |
| 67 | + if isinstance(type, NonNull): |
| 68 | + return type.of_type |
| 69 | + return type |
55 | 70 |
|
56 | 71 | @property |
57 | 72 | def node_type(self): |
58 | | - return self.type._meta.node |
| 73 | + return self.connection_type._meta.node |
59 | 74 |
|
60 | 75 | @property |
61 | 76 | def model(self): |
@@ -103,15 +118,15 @@ def resolve_connection(cls, connection, default_manager, args, iterable): |
103 | 118 |
|
104 | 119 | @classmethod |
105 | 120 | def connection_resolver( |
106 | | - cls, |
107 | | - resolver, |
108 | | - connection, |
109 | | - default_manager, |
110 | | - max_limit, |
111 | | - enforce_first_or_last, |
112 | | - root, |
113 | | - info, |
114 | | - **args |
| 121 | + cls, |
| 122 | + resolver, |
| 123 | + connection, |
| 124 | + default_manager, |
| 125 | + max_limit, |
| 126 | + enforce_first_or_last, |
| 127 | + root, |
| 128 | + info, |
| 129 | + **args |
115 | 130 | ): |
116 | 131 | first = args.get("first") |
117 | 132 | last = args.get("last") |
@@ -146,7 +161,7 @@ def get_resolver(self, parent_resolver): |
146 | 161 | return partial( |
147 | 162 | self.connection_resolver, |
148 | 163 | parent_resolver, |
149 | | - self.type, |
| 164 | + self.connection_type, |
150 | 165 | self.get_manager(), |
151 | 166 | self.max_limit, |
152 | 167 | self.enforce_first_or_last, |
|
0 commit comments