@@ -110,13 +110,36 @@ func (r *Reader) Lookup(ipAddress net.IP, result interface{}) error {
110110 if r .buffer == nil {
111111 return errors .New ("cannot call Lookup on a closed database" )
112112 }
113- pointer , err := r .lookupPointer (ipAddress )
113+ pointer , _ , _ , err := r .lookupPointer (ipAddress )
114114 if pointer == 0 || err != nil {
115115 return err
116116 }
117117 return r .retrieveData (pointer , result )
118118}
119119
120+ // LookupNetwork retrieves the database record for ipAddress and stores it in
121+ // the value pointed to be result. The network returned is the network
122+ // associated with the data record in the database. The ok return value
123+ // indicates whether the database contained a record for the ipAddress.
124+ //
125+ // If result is nil or not a pointer, an error is returned. If the data in the
126+ // database record cannot be stored in result because of type differences, an
127+ // UnmarshalTypeError is returned. If the database is invalid or otherwise
128+ // cannot be read, an InvalidDatabaseError is returned.
129+ func (r * Reader ) LookupNetwork (ipAddress net.IP , result interface {}) (network * net.IPNet , ok bool , err error ) {
130+ if r .buffer == nil {
131+ return nil , false , errors .New ("cannot call Lookup on a closed database" )
132+ }
133+ pointer , prefixLength , ipAddress , err := r .lookupPointer (ipAddress )
134+
135+ network = r .cidr (ipAddress , prefixLength )
136+ if pointer == 0 || err != nil {
137+ return network , false , err
138+ }
139+
140+ return network , true , r .retrieveData (pointer , result )
141+ }
142+
120143// LookupOffset maps an argument net.IP to a corresponding record offset in the
121144// database. NotFound is returned if no such record is found, and a record may
122145// otherwise be extracted by passing the returned offset to Decode. LookupOffset
@@ -126,13 +149,20 @@ func (r *Reader) LookupOffset(ipAddress net.IP) (uintptr, error) {
126149 if r .buffer == nil {
127150 return 0 , errors .New ("cannot call LookupOffset on a closed database" )
128151 }
129- pointer , err := r .lookupPointer (ipAddress )
152+ pointer , _ , _ , err := r .lookupPointer (ipAddress )
130153 if pointer == 0 || err != nil {
131154 return NotFound , err
132155 }
133156 return r .resolveDataPointer (pointer )
134157}
135158
159+ func (r * Reader ) cidr (ipAddress net.IP , prefixLength int ) * net.IPNet {
160+ ipBitLength := len (ipAddress ) * 8
161+ mask := net .CIDRMask (prefixLength , ipBitLength )
162+
163+ return & net.IPNet {IP : ipAddress .Mask (mask ), Mask : mask }
164+ }
165+
136166// Decode the record at |offset| into |result|. The result value pointed to
137167// must be a data value that corresponds to a record in the database. This may
138168// include a struct representation of the data, a map capable of holding the
@@ -166,24 +196,19 @@ func (r *Reader) decode(offset uintptr, result interface{}) error {
166196 return err
167197}
168198
169- func (r * Reader ) lookupPointer (ipAddress net.IP ) (uint , error ) {
199+ func (r * Reader ) lookupPointer (ipAddress net.IP ) (uint , int , net. IP , error ) {
170200 if ipAddress == nil {
171- return 0 , errors .New ("ipAddress passed to Lookup cannot be nil" )
201+ return 0 , 0 , ipAddress , errors .New ("ipAddress passed to Lookup cannot be nil" )
172202 }
173203
174204 ipV4Address := ipAddress .To4 ()
175205 if ipV4Address != nil {
176206 ipAddress = ipV4Address
177207 }
178208 if len (ipAddress ) == 16 && r .Metadata .IPVersion == 4 {
179- return 0 , fmt .Errorf ("error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database" , ipAddress .String ())
209+ return 0 , 0 , ipAddress , fmt .Errorf ("error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database" , ipAddress .String ())
180210 }
181211
182- return r .findAddressInTree (ipAddress )
183- }
184-
185- func (r * Reader ) findAddressInTree (ipAddress net.IP ) (uint , error ) {
186-
187212 bitCount := uint (len (ipAddress ) * 8 )
188213
189214 var node uint
@@ -193,23 +218,24 @@ func (r *Reader) findAddressInTree(ipAddress net.IP) (uint, error) {
193218
194219 nodeCount := r .Metadata .NodeCount
195220
196- for i := uint (0 ); i < bitCount && node < nodeCount ; i ++ {
221+ i := uint (0 )
222+ for ; i < bitCount && node < nodeCount ; i ++ {
197223 bit := uint (1 ) & (uint (ipAddress [i >> 3 ]) >> (7 - (i % 8 )))
198224
199225 var err error
200226 node , err = r .readNode (node , bit )
201227 if err != nil {
202- return 0 , err
228+ return 0 , int ( i ), ipAddress , err
203229 }
204230 }
205231 if node == nodeCount {
206232 // Record is empty
207- return 0 , nil
233+ return 0 , int ( i ), ipAddress , nil
208234 } else if node > nodeCount {
209- return node , nil
235+ return node , int ( i ), ipAddress , nil
210236 }
211237
212- return 0 , newInvalidDatabaseError ("invalid node in search tree" )
238+ return 0 , int ( i ), ipAddress , newInvalidDatabaseError ("invalid node in search tree" )
213239}
214240
215241func (r * Reader ) readNode (nodeNumber uint , index uint ) (uint , error ) {
0 commit comments