Skip to content

Commit a9fa049

Browse files
committed
fix(graph, node): add option to authenticate Flight service requests
1 parent ef5c9a9 commit a9fa049

File tree

5 files changed

+42
-11
lines changed

5 files changed

+42
-11
lines changed

graph/src/amp/client/flight_client.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use crate::{
3333
/// using the Apache Arrow Flight protocol.
3434
pub struct FlightClient {
3535
channel: Channel,
36+
auth_token: Option<String>,
3637
}
3738

3839
impl FlightClient {
@@ -56,16 +57,27 @@ impl FlightClient {
5657

5758
Ok(Self {
5859
channel: endpoint.connect().await.map_err(Error::Connection)?,
60+
auth_token: None,
5961
})
6062
}
6163

64+
/// Sets the authentication token for requests to the Amp server.
65+
pub fn set_auth_token(&mut self, auth_token: impl Into<String>) {
66+
self.auth_token = Some(auth_token.into());
67+
}
68+
6269
fn raw_client(&self) -> FlightSqlServiceClient<Channel> {
6370
let channel = self.channel.cheap_clone();
6471
let client = FlightServiceClient::new(channel)
6572
.max_encoding_message_size(256 * 1024 * 1024)
6673
.max_decoding_message_size(256 * 1024 * 1024);
6774

68-
FlightSqlServiceClient::new_from_inner(client)
75+
let mut client = FlightSqlServiceClient::new_from_inner(client);
76+
if let Some(auth_token) = &self.auth_token {
77+
client.set_token(auth_token.clone());
78+
}
79+
80+
client
6981
}
7082
}
7183

graph/src/env/amp.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ pub struct AmpEnv {
2424
///
2525
/// Defaults to `600` seconds.
2626
pub query_retry_max_delay: Duration,
27+
28+
/// Token used to authenticate Amp Flight gRPC service requests.
29+
///
30+
/// Defaults to `None`.
31+
pub flight_service_token: Option<String>,
2732
}
2833

2934
impl AmpEnv {
@@ -60,6 +65,12 @@ impl AmpEnv {
6065
.amp_query_retry_max_delay_seconds
6166
.map(Duration::from_secs)
6267
.unwrap_or(Self::DEFAULT_QUERY_RETRY_MAX_DELAY),
68+
flight_service_token: raw_env.amp_flight_service_token.as_ref().and_then(|value| {
69+
if value.is_empty() {
70+
return None;
71+
}
72+
Some(value.to_string())
73+
}),
6374
}
6475
}
6576
}

graph/src/env/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,8 @@ struct Inner {
561561
amp_query_retry_min_delay_seconds: Option<u64>,
562562
#[envconfig(from = "GRAPH_AMP_QUERY_RETRY_MAX_DELAY_SECONDS")]
563563
amp_query_retry_max_delay_seconds: Option<u64>,
564+
#[envconfig(from = "GRAPH_AMP_FLIGHT_SERVICE_TOKEN")]
565+
amp_flight_service_token: Option<String>,
564566
}
565567

566568
#[derive(Clone, Debug)]

node/src/main.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,15 @@ async fn main_inner() {
385385
.parse()
386386
.expect("Invalid Amp Flight service address");
387387

388-
let amp_client = Arc::new(
389-
amp::FlightClient::new(addr)
390-
.await
391-
.expect("Failed to connect to Amp Flight service"),
392-
);
388+
let mut amp_client = amp::FlightClient::new(addr)
389+
.await
390+
.expect("Failed to connect to Amp Flight service");
391+
392+
if let Some(auth_token) = &env_vars.amp.flight_service_token {
393+
amp_client.set_auth_token(auth_token);
394+
}
393395

396+
let amp_client = Arc::new(amp_client);
394397
let amp_instance_manager = graph_core::amp_subgraph::Manager::new(
395398
&logger_factory,
396399
metrics_registry.cheap_clone(),

node/src/manager/commands/run.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,15 @@ pub async fn run(
156156
.parse()
157157
.expect("Invalid Amp Flight service address");
158158

159-
let amp_client = Arc::new(
160-
amp::FlightClient::new(addr)
161-
.await
162-
.expect("Failed to connect to Amp Flight service"),
163-
);
159+
let mut amp_client = amp::FlightClient::new(addr)
160+
.await
161+
.expect("Failed to connect to Amp Flight service");
162+
163+
if let Some(auth_token) = &env_vars.amp.flight_service_token {
164+
amp_client.set_auth_token(auth_token);
165+
}
164166

167+
let amp_client = Arc::new(amp_client);
165168
let amp_instance_manager = graph_core::amp_subgraph::Manager::new(
166169
&logger_factory,
167170
metrics_registry.cheap_clone(),

0 commit comments

Comments
 (0)