+use std::hash::{DefaultHasher, Hash, Hasher};
+
 use futures_util::FutureExt;
 use gas::{db::WorkflowData, prelude::*};
 use rivet_types::{keys, runner_configs::RunnerConfigKind};
@@ -12,54 +14,79 @@ pub struct Input {
 
 #[derive(Debug, Serialize, Deserialize, Default)]
 struct LifecycleState {
-    runners: Vec<Id>,
+    runners: Vec<RunnerState>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct RunnerState {
+    /// Workflow id of the serverless runner workflow, not of a normal runner workflow.
+    runner_wf_id: Id,
+    details_hash: u64,
 }
 
 #[workflow]
 pub async fn pegboard_serverless_pool(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> {
     ctx.loope(LifecycleState::default(), |ctx, state| {
         let input = input.clone();
         async move {
-            // 1. Remove completed connections
+            // Get desired count -> drain and start counts
+            let ReadDesiredOutput::Desired {
+                desired_count,
+                details_hash,
+            } = ctx.activity(ReadDesiredInput {
+                namespace_id: input.namespace_id,
+                runner_name: input.runner_name.clone(),
+            })
+            .await?
+            else {
+                return Ok(Loop::Break(()));
+            };
+
             let completed_runners = ctx
                 .activity(GetCompletedInput {
-                    runners: state.runners.clone(),
+                    runners: state.runners.iter().map(|r| r.runner_wf_id).collect(),
                 })
                 .await?;
 
-            state.runners.retain(|r| !completed_runners.contains(r));
-
-            // 2. Get desired count -> drain and start counts
-            let ReadDesiredOutput::Desired(desired_count) = ctx
-                .activity(ReadDesiredInput {
-                    namespace_id: input.namespace_id,
-                    runner_name: input.runner_name.clone(),
-                })
-                .await?
-            else {
-                return Ok(Loop::Break(()));
-            };
+            // Remove completed connections
+            state
+                .runners
+                .retain(|r| !completed_runners.contains(&r.runner_wf_id));
+
+            // Remove runners that have an outdated hash. This is handled outside of the draining
+            // mechanism below because here we drain specific runners, not just a number of them.
+            let (new, outdated) = std::mem::take(&mut state.runners)
+                .into_iter()
+                .partition::<Vec<_>, _>(|r| r.details_hash == details_hash);
+            state.runners = new;
+
+            for runner in outdated {
+                ctx.signal(runner::Drain {})
+                    .to_workflow_id(runner.runner_wf_id)
+                    .send()
+                    .await?;
+            }
 
             let drain_count = state.runners.len().saturating_sub(desired_count);
             let start_count = desired_count.saturating_sub(state.runners.len());
 
-            // 3. Drain old runners
+            // Drain unnecessary runners
             if drain_count != 0 {
                 // TODO: Implement smart logic of draining runners with the lowest allocated actors
                 let draining_runners = state.runners.iter().take(drain_count).collect::<Vec<_>>();
 
-                for wf_id in draining_runners {
+                for runner in draining_runners {
                     ctx.signal(runner::Drain {})
-                        .to_workflow_id(*wf_id)
+                        .to_workflow_id(runner.runner_wf_id)
                         .send()
                         .await?;
                 }
             }
 
-            // 4. Dispatch new runner workflows
+            // Dispatch new runner workflows
             if start_count != 0 {
                 for _ in 0..start_count {
-                    let wf_id = ctx
+                    let runner_wf_id = ctx
                         .workflow(runner::Input {
                             pool_wf_id: ctx.workflow_id(),
                             namespace_id: input.namespace_id,
@@ -70,14 +97,17 @@ pub async fn pegboard_serverless_pool(ctx: &mut WorkflowCtx, input: &Input) -> R
                         .dispatch()
                         .await?;
 
-                    state.runners.push(wf_id);
+                    state.runners.push(RunnerState {
+                        runner_wf_id,
+                        details_hash,
+                    });
                 }
             }
 
             // Wait for Bump or runner update signals until we tick again
             match ctx.listen::<Main>().await? {
                 Main::RunnerDrainStarted(sig) => {
-                    state.runners.retain(|wf_id| *wf_id != sig.runner_wf_id);
+                    state.runners.retain(|r| r.runner_wf_id != sig.runner_wf_id);
                 }
                 Main::Bump(_) => {}
             }
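The outdated-runner handling in this hunk is essentially a partition by hash followed by a drain signal per outdated runner. Below is a minimal standalone sketch of that pattern; the `RunnerState` shape mirrors the diff, but `runner_wf_id` is a plain `u64` stand-in for the real `Id` type, and the drain signal is replaced by a print since the workflow context is not available here:

```rust
#[derive(Debug)]
struct RunnerState {
    // Stand-in for the real workflow `Id` type.
    runner_wf_id: u64,
    details_hash: u64,
}

/// Split tracked runners into those matching the current serverless details
/// and those that are outdated and must be drained, as the workflow does.
fn split_outdated(
    runners: Vec<RunnerState>,
    details_hash: u64,
) -> (Vec<RunnerState>, Vec<RunnerState>) {
    runners
        .into_iter()
        .partition(|r| r.details_hash == details_hash)
}

fn main() {
    let runners = vec![
        RunnerState { runner_wf_id: 1, details_hash: 42 },
        RunnerState { runner_wf_id: 2, details_hash: 7 },
        RunnerState { runner_wf_id: 3, details_hash: 42 },
    ];

    let (kept, outdated) = split_outdated(runners, 42);
    // Runners 1 and 3 stay tracked; runner 2 would receive a Drain signal.
    println!("kept: {kept:?}");
    println!("outdated: {outdated:?}");
}
```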
@@ -102,6 +132,7 @@ async fn get_completed(ctx: &ActivityCtx, input: &GetCompletedInput) -> Result<V
         .get_workflows(input.runners.clone())
         .await?
         .into_iter()
+        // When a workflow has output, it means it has completed
         .filter(WorkflowData::has_output)
         .map(|wf| wf.workflow_id)
         .collect())
@@ -115,7 +146,10 @@ struct ReadDesiredInput {
 
 #[derive(Debug, Serialize, Deserialize)]
 enum ReadDesiredOutput {
-    Desired(usize),
+    Desired {
+        desired_count: usize,
+        details_hash: u64,
+    },
     Stop,
 }
 
@@ -132,6 +166,9 @@ async fn read_desired(ctx: &ActivityCtx, input: &ReadDesiredInput) -> Result<Rea
     };
 
     let RunnerConfigKind::Serverless {
+        url,
+        headers,
+
         slots_per_runner,
         min_runners,
         max_runners,
@@ -177,7 +214,18 @@ async fn read_desired(ctx: &ActivityCtx, input: &ReadDesiredInput) -> Result<Rea
         .min(max_runners)
         .try_into()?;
 
-    Ok(ReadDesiredOutput::Desired(desired_count))
+    // Compute consistent hash of serverless details
+    let mut hasher = DefaultHasher::new();
+    url.hash(&mut hasher);
+    let mut sorted_headers = headers.iter().collect::<Vec<_>>();
+    sorted_headers.sort();
+    sorted_headers.hash(&mut hasher);
+    let details_hash = hasher.finish();
+
+    Ok(ReadDesiredOutput::Desired {
+        desired_count,
+        details_hash,
+    })
 }
 
 #[signal("pegboard_serverless_bump")]
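The hash added in `read_desired` is meant to be order-independent: the headers are collected and sorted before hashing, so the same URL and header set always produce the same `details_hash` regardless of map iteration order. A minimal sketch of that behavior, assuming `headers` is an ordinary `HashMap<String, String>` (the actual type in the codebase may differ); the `details_hash` helper and the URL/header values are made up for illustration:

```rust
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};

// Hypothetical helper mirroring the hashing done in `read_desired`:
// hash the URL, then the headers sorted by key.
fn details_hash(url: &str, headers: &HashMap<String, String>) -> u64 {
    let mut hasher = DefaultHasher::new();
    url.hash(&mut hasher);
    let mut sorted_headers = headers.iter().collect::<Vec<_>>();
    sorted_headers.sort();
    sorted_headers.hash(&mut hasher);
    hasher.finish()
}

fn main() {
    let mut a = HashMap::new();
    a.insert("x-api-key".to_string(), "abc".to_string());
    a.insert("x-region".to_string(), "us-east-1".to_string());

    // Same entries, inserted in the opposite order.
    let mut b = HashMap::new();
    b.insert("x-region".to_string(), "us-east-1".to_string());
    b.insert("x-api-key".to_string(), "abc".to_string());

    // Sorting before hashing makes the result independent of insertion order.
    let url = "https://runners.example.com";
    assert_eq!(details_hash(url, &a), details_hash(url, &b));

    // Changing any detail yields a different hash, which the pool workflow
    // uses to detect runners started with outdated serverless details.
    assert_ne!(details_hash(url, &a), details_hash("https://other.example.com", &a));

    println!("details_hash: {}", details_hash(url, &a));
}
```

Because the pool workflow compares only these hashes, any change to the configured URL or headers causes every existing runner to be drained and replaced on the next loop iteration.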