diff --git a/dropshot/src/api_description.rs b/dropshot/src/api_description.rs index 3f832982b..480eed911 100644 --- a/dropshot/src/api_description.rs +++ b/dropshot/src/api_description.rs @@ -670,15 +670,15 @@ impl ApiDescription { _ => panic!("reference not expected"), }; - let method_ref = match &method[..] { - "GET" => &mut pathitem.get, - "PUT" => &mut pathitem.put, - "POST" => &mut pathitem.post, - "DELETE" => &mut pathitem.delete, - "OPTIONS" => &mut pathitem.options, - "HEAD" => &mut pathitem.head, - "PATCH" => &mut pathitem.patch, - "TRACE" => &mut pathitem.trace, + let method_ref = match method { + http::Method::GET => &mut pathitem.get, + http::Method::PUT => &mut pathitem.put, + http::Method::POST => &mut pathitem.post, + http::Method::DELETE => &mut pathitem.delete, + http::Method::OPTIONS => &mut pathitem.options, + http::Method::HEAD => &mut pathitem.head, + http::Method::PATCH => &mut pathitem.patch, + http::Method::TRACE => &mut pathitem.trace, other => panic!("unexpected method `{}`", other), }; let mut operation = openapiv3::Operation::default(); diff --git a/dropshot/src/router.rs b/dropshot/src/router.rs index ed3af536d..5dc273553 100644 --- a/dropshot/src/router.rs +++ b/dropshot/src/router.rs @@ -12,8 +12,10 @@ use crate::ApiEndpointBodyContentType; use http::Method; use http::StatusCode; use percent_encoding::percent_decode_str; +use std::borrow::Cow; use std::collections::BTreeMap; use std::collections::BTreeSet; +use std::collections::HashMap; use std::sync::Arc; /// `HttpRouter` is a simple data structure for routing incoming HTTP requests to @@ -81,7 +83,7 @@ pub struct HttpRouter { #[derive(Debug)] struct HttpRouterNode { /// Handlers, etc. for each of the HTTP methods defined for this node. - methods: BTreeMap>, + methods: HashMap>, /// Edges linking to child nodes. edges: Option>, } @@ -217,7 +219,7 @@ pub struct RouterLookupResult { impl HttpRouterNode { pub fn new() -> Self { - HttpRouterNode { methods: BTreeMap::new(), edges: None } + HttpRouterNode { methods: HashMap::new(), edges: None } } } @@ -385,8 +387,7 @@ impl HttpRouter { }; } - let methodname = method.as_str().to_uppercase(); - if node.methods.contains_key(&methodname) { + if node.methods.contains_key(&method) { panic!( "URI path \"{}\": attempted to create duplicate route for \ method \"{}\"", @@ -394,7 +395,7 @@ impl HttpRouter { ); } - node.methods.insert(methodname, endpoint); + node.methods.insert(method, endpoint); } /// Look up the route handler for an HTTP request having method `method` and @@ -408,36 +409,41 @@ impl HttpRouter { method: &Method, path: InputPath<'_>, ) -> Result, HttpError> { - let all_segments = input_path_to_segments(&path).map_err(|_| { - HttpError::for_bad_request( - None, - String::from("invalid path encoding"), - ) - })?; - let mut all_segments = all_segments.into_iter(); + let mut all_segments = input_path_to_segments(&path); let mut node = &self.root; let mut variables = VariableSet::new(); - while let Some(segment) = all_segments.next() { - let segment_string = segment.to_string(); + while let Some(maybe_segment) = all_segments.next() { + let segment = maybe_segment.map_err(|e| { + HttpError::for_bad_request( + None, + format!("invalid path encoding: {e}"), + ) + })?; node = match &node.edges { None => None, Some(HttpRouterEdges::Literals(edges)) => { - edges.get(&segment_string) + edges.get(segment.as_ref()) } Some(HttpRouterEdges::VariableSingle(varname, ref node)) => { variables.insert( varname.clone(), - VariableValue::String(segment_string), + VariableValue::String(segment.into_owned()), ); Some(node) } Some(HttpRouterEdges::VariableRest(varname, node)) => { - let mut rest = vec![segment]; - while let Some(segment) = all_segments.next() { - rest.push(segment); + let mut rest = vec![segment.into_owned()]; + while let Some(maybe_segment) = all_segments.next() { + let segment = maybe_segment.map_err(|e| { + HttpError::for_bad_request( + None, + format!("invalid path encoding: {e}"), + ) + })?; + rest.push(segment.into_owned()); } variables.insert( varname.clone(), @@ -478,9 +484,8 @@ impl HttpRouter { )); } - let methodname = method.as_str().to_uppercase(); node.methods - .get(&methodname) + .get(&method) .map(|handler| RouterLookupResult { handler: Arc::clone(&handler.handler), operation_id: handler.operation_id.clone(), @@ -512,7 +517,7 @@ fn insert_var( } impl<'a, Context: ServerContext> IntoIterator for &'a HttpRouter { - type Item = (String, String, &'a ApiEndpoint); + type Item = (String, Method, &'a ApiEndpoint); type IntoIter = HttpRouterIter<'a, Context>; fn into_iter(self) -> Self::IntoIter { HttpRouterIter::new(self) @@ -529,7 +534,7 @@ impl<'a, Context: ServerContext> IntoIterator for &'a HttpRouter { /// blank string and an iterator over the root node's children. pub struct HttpRouterIter<'a, Context: ServerContext> { method: - Box)> + 'a>, + Box)> + 'a>, path: Vec<(PathSegment, Box>)>, } type PathIter<'a, Context> = @@ -592,7 +597,7 @@ impl<'a, Context: ServerContext> HttpRouterIter<'a, Context> { } impl<'a, Context: ServerContext> Iterator for HttpRouterIter<'a, Context> { - type Item = (String, String, &'a ApiEndpoint); + type Item = (String, Method, &'a ApiEndpoint); fn next(&mut self) -> Option { // If there are no path components left then we've reached the end of @@ -630,6 +635,14 @@ impl<'a, Context: ServerContext> Iterator for HttpRouterIter<'a, Context> { } } +#[derive(Debug, thiserror::Error)] +enum InputPathError { + #[error(transparent)] + PercentDecode(#[from] std::str::Utf8Error), + #[error("dot-segments are not permitted")] + DotSegment, +} + /// Helper function for taking a Uri path and producing a `Vec` of /// URL-decoded strings, each representing one segment of the path. The input is /// percent-encoded. Empty segments i.e. due to consecutive "/" characters or a @@ -653,7 +666,9 @@ impl<'a, Context: ServerContext> Iterator for HttpRouterIter<'a, Context> { /// that consumers may be susceptible to other information leaks, for example /// if a client were able to follow a symlink to the root of the filesystem. As /// always, it is incumbent on the consumer and *critical* to validate input. -fn input_path_to_segments(path: &InputPath) -> Result, String> { +fn input_path_to_segments<'path>( + path: &'path InputPath, +) -> impl Iterator, InputPathError>> + 'path { // We're given the "path" portion of a URI and we want to construct an // array of the segments of the path. Relevant references: // @@ -682,17 +697,12 @@ fn input_path_to_segments(path: &InputPath) -> Result, String> { // should be ignored). The net result is that that crate doesn't buy us // much here, but it does create more work, so we'll just split it // ourselves. - path.0 - .split('/') - .filter(|segment| !segment.is_empty()) - .map(|segment| match segment { - "." | ".." => Err("dot-segments are not permitted".to_string()), - _ => Ok(percent_decode_str(segment) - .decode_utf8() - .map_err(|e| e.to_string())? - .to_string()), - }) - .collect() + path.0.split('/').filter(|segment| !segment.is_empty()).map(|segment| { + match segment { + "." | ".." => Err(InputPathError::DotSegment), + _ => Ok(percent_decode_str(segment).decode_utf8()?), + } + }) } /// Whereas in `input_path_to_segments()` we must accommodate any user input, when @@ -729,6 +739,7 @@ mod test { use super::super::handler::RouteHandler; use super::input_path_to_segments; use super::HttpRouter; + use super::InputPathError; use super::PathSegment; use crate::api_description::ApiEndpointBodyContentType; use crate::from_map::from_map; @@ -1309,10 +1320,10 @@ mod test { assert_eq!( ret, vec![ - ("/".to_string(), "GET".to_string(),), + ("/".to_string(), http::Method::GET,), ( "/projects/{project_id}/instances".to_string(), - "GET".to_string(), + http::Method::GET, ), ] ); @@ -1335,16 +1346,18 @@ mod test { assert_eq!( ret, vec![ - ("/".to_string(), "GET".to_string(),), - ("/".to_string(), "POST".to_string(),), + ("/".to_string(), http::Method::GET,), + ("/".to_string(), http::Method::POST), ] ); } #[test] fn test_segments() { - let segs = - input_path_to_segments(&"//foo/bar/baz%2fbuzz".into()).unwrap(); + let segs = input_path_to_segments(&"//foo/bar/baz%2fbuzz".into()) + .map(|seg| Ok::(seg?.into_owned())) + .collect::, _>>() + .unwrap(); assert_eq!(segs, vec!["foo", "bar", "baz/buzz"]); } diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index ab6135bbf..ce1188ee3 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -206,7 +206,7 @@ impl HttpServerStarter { for (path, method, _) in &app_state.router { debug!(&log, "registered endpoint"; - "method" => &method, + "method" => %method, "path" => &path ); }