Skip to content

Commit 459028c

Browse files
committed
fix(graph, node): add option to authenticate Flight service requests
1 parent e4d71e8 commit 459028c

File tree

5 files changed

+39
-7
lines changed

5 files changed

+39
-7
lines changed

graph/src/amp/client/flight_client.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use crate::{
3333
/// using the Apache Arrow Flight protocol.
3434
pub struct FlightClient {
3535
channel: Channel,
36+
auth_token: Option<String>,
3637
}
3738

3839
impl FlightClient {
@@ -56,16 +57,27 @@ impl FlightClient {
5657

5758
Ok(Self {
5859
channel: endpoint.connect().await.map_err(Error::Connection)?,
60+
auth_token: None,
5961
})
6062
}
6163

64+
/// Sets the authentication token for requests to the Amp server.
65+
pub fn set_auth_token(&mut self, auth_token: impl Into<String>) {
66+
self.auth_token = Some(auth_token.into());
67+
}
68+
6269
fn raw_client(&self) -> FlightSqlServiceClient<Channel> {
6370
let channel = self.channel.cheap_clone();
6471
let client = FlightServiceClient::new(channel)
6572
.max_encoding_message_size(256 * 1024 * 1024)
6673
.max_decoding_message_size(256 * 1024 * 1024);
6774

68-
FlightSqlServiceClient::new_from_inner(client)
75+
let mut client = FlightSqlServiceClient::new_from_inner(client);
76+
if let Some(auth_token) = &self.auth_token {
77+
client.set_token(auth_token.clone());
78+
}
79+
80+
client
6981
}
7082
}
7183

graph/src/env/amp.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ pub struct AmpEnv {
2424
///
2525
/// Defaults to `600` seconds.
2626
pub query_retry_max_delay: Duration,
27+
28+
/// Token used to authenticate Amp Flight gRPC service requests.
29+
///
30+
/// Defaults to `None`.
31+
pub flight_service_token: Option<String>,
2732
}
2833

2934
impl AmpEnv {
@@ -60,6 +65,12 @@ impl AmpEnv {
6065
.amp_query_retry_max_delay_seconds
6166
.map(Duration::from_secs)
6267
.unwrap_or(Self::DEFAULT_QUERY_RETRY_MAX_DELAY),
68+
flight_service_token: raw_env.amp_flight_service_token.as_ref().and_then(|value| {
69+
if value.is_empty() {
70+
return None;
71+
}
72+
Some(value.to_string())
73+
}),
6374
}
6475
}
6576
}

graph/src/env/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,8 @@ struct Inner {
602602
amp_query_retry_min_delay_seconds: Option<u64>,
603603
#[envconfig(from = "GRAPH_AMP_QUERY_RETRY_MAX_DELAY_SECONDS")]
604604
amp_query_retry_max_delay_seconds: Option<u64>,
605+
#[envconfig(from = "GRAPH_AMP_FLIGHT_SERVICE_TOKEN")]
606+
amp_flight_service_token: Option<String>,
605607
}
606608

607609
#[derive(Clone, Debug)]

node/src/launcher.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,10 +505,14 @@ pub async fn run(
505505
.parse()
506506
.expect("Invalid Amp Flight service address");
507507

508-
let amp_client = amp::FlightClient::new(addr)
508+
let mut amp_client = amp::FlightClient::new(addr)
509509
.await
510510
.expect("Failed to connect to Amp Flight service");
511511

512+
if let Some(auth_token) = &env_vars.amp.flight_service_token {
513+
amp_client.set_auth_token(auth_token);
514+
}
515+
512516
Some(Arc::new(amp_client))
513517
}
514518
None => None,

node/src/manager/commands/run.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,15 @@ pub async fn run(
150150
.parse()
151151
.expect("Invalid Amp Flight service address");
152152

153-
let amp_client = Arc::new(
154-
amp::FlightClient::new(addr)
155-
.await
156-
.expect("Failed to connect to Amp Flight service"),
157-
);
153+
let mut amp_client = amp::FlightClient::new(addr)
154+
.await
155+
.expect("Failed to connect to Amp Flight service");
156+
157+
if let Some(auth_token) = &env_vars.amp.flight_service_token {
158+
amp_client.set_auth_token(auth_token);
159+
}
158160

161+
let amp_client = Arc::new(amp_client);
159162
let amp_instance_manager = graph_core::amp_subgraph::Manager::new(
160163
&logger_factory,
161164
metrics_registry.cheap_clone(),

0 commit comments

Comments
 (0)