@@ -110,6 +110,8 @@ pub struct ClientSession {
110110 pub ( crate ) transaction : Transaction ,
111111 pub ( crate ) snapshot_time : Option < Timestamp > ,
112112 pub ( crate ) operation_time : Option < Timestamp > ,
113+ #[ cfg( test) ]
114+ pub ( crate ) convenient_transaction_timeout : Option < Duration > ,
113115}
114116
115117#[ derive( Debug ) ]
@@ -216,6 +218,8 @@ impl ClientSession {
216218 transaction : Default :: default ( ) ,
217219 snapshot_time : None ,
218220 operation_time : None ,
221+ #[ cfg( test) ]
222+ convenient_transaction_timeout : None ,
219223 }
220224 }
221225
@@ -561,13 +565,117 @@ impl ClientSession {
561565 }
562566 }
563567
568+ /// Starts a transaction, runs the given callback, and commits or aborts the transaction.
569+ /// Transient transaction errors will cause the callback or the commit to be retried;
570+ /// other errors will cause the transaction to be aborted and the error returned to the
571+ /// caller. If the callback needs to provide its own error information, the
572+ /// [`Error::custom`](crate::error::Error::custom) method can accept an arbitrary payload that
573+ /// can be retrieved via [`Error::get_custom`](crate::error::Error::get_custom).
574+ ///
575+ /// Because the callback can be repeatedly executed and because it returns a future, the rust
576+ /// closure borrowing rules for captured values can be overly restrictive. As a
577+ /// convenience, `with_transaction` accepts a context argument that will be passed to the
578+ /// callback along with the session:
579+ ///
580+ /// ```no_run
581+ /// # use mongodb::{bson::{doc, Document}, error::Result, Client};
582+ /// # use futures::FutureExt;
583+ /// # async fn wrapper() -> Result<()> {
584+ /// # let client = Client::with_uri_str("mongodb://example.com").await?;
585+ /// # let mut session = client.start_session(None).await?;
586+ /// let coll = client.database("mydb").collection::<Document>("mycoll");
587+ /// let my_data = "my data".to_string();
588+ /// // This works:
589+ /// session.with_transaction(
590+ /// (&coll, &my_data),
591+ /// |session, (coll, my_data)| async move {
592+ /// coll.insert_one_with_session(doc! { "data": *my_data }, None, session).await
593+ /// }.boxed(),
594+ /// None,
595+ /// ).await?;
596+ /// /* This will not compile with a "variable moved due to use in generator" error:
597+ /// session.with_transaction(
598+ /// (),
599+ /// |session, _| async move {
600+ /// coll.insert_one_with_session(doc! { "data": my_data }, None, session).await
601+ /// }.boxed(),
602+ /// None,
603+ /// ).await?;
604+ /// */
605+ /// # Ok(())
606+ /// # }
607+ /// ```
608+ pub async fn with_transaction < R , C , F > (
609+ & mut self ,
610+ mut context : C ,
611+ mut callback : F ,
612+ options : impl Into < Option < TransactionOptions > > ,
613+ ) -> Result < R >
614+ where
615+ F : for < ' a > FnMut ( & ' a mut ClientSession , & ' a mut C ) -> BoxFuture < ' a , Result < R > > ,
616+ {
617+ let options = options. into ( ) ;
618+ let timeout = Duration :: from_secs ( 120 ) ;
619+ #[ cfg( test) ]
620+ let timeout = self . convenient_transaction_timeout . unwrap_or ( timeout) ;
621+ let start = Instant :: now ( ) ;
622+
623+ use crate :: error:: { TRANSIENT_TRANSACTION_ERROR , UNKNOWN_TRANSACTION_COMMIT_RESULT } ;
624+
625+ ' transaction: loop {
626+ self . start_transaction ( options. clone ( ) ) . await ?;
627+ let ret = match callback ( self , & mut context) . await {
628+ Ok ( v) => v,
629+ Err ( e) => {
630+ if matches ! (
631+ self . transaction. state,
632+ TransactionState :: Starting | TransactionState :: InProgress
633+ ) {
634+ self . abort_transaction ( ) . await ?;
635+ }
636+ if e. contains_label ( TRANSIENT_TRANSACTION_ERROR ) && start. elapsed ( ) < timeout {
637+ continue ' transaction;
638+ }
639+ return Err ( e) ;
640+ }
641+ } ;
642+ if matches ! (
643+ self . transaction. state,
644+ TransactionState :: None
645+ | TransactionState :: Aborted
646+ | TransactionState :: Committed { .. }
647+ ) {
648+ return Ok ( ret) ;
649+ }
650+ ' commit: loop {
651+ match self . commit_transaction ( ) . await {
652+ Ok ( ( ) ) => return Ok ( ret) ,
653+ Err ( e) => {
654+ if e. is_max_time_ms_expired_error ( ) || start. elapsed ( ) >= timeout {
655+ return Err ( e) ;
656+ }
657+ if e. contains_label ( UNKNOWN_TRANSACTION_COMMIT_RESULT ) {
658+ continue ' commit;
659+ }
660+ if e. contains_label ( TRANSIENT_TRANSACTION_ERROR ) {
661+ continue ' transaction;
662+ }
663+ return Err ( e) ;
664+ }
665+ }
666+ }
667+ }
668+ }
669+
564670 fn default_transaction_options ( & self ) -> Option < & TransactionOptions > {
565671 self . options
566672 . as_ref ( )
567673 . and_then ( |options| options. default_transaction_options . as_ref ( ) )
568674 }
569675}
570676
677+ pub type BoxFuture < ' a , T > = std:: pin:: Pin < Box < dyn std:: future:: Future < Output = T > + Send + ' a > > ;
678+
571679struct DroppedClientSession {
572680 cluster_time : Option < ClusterTime > ,
573681 server_session : ServerSession ,
@@ -590,6 +698,8 @@ impl From<DroppedClientSession> for ClientSession {
590698 transaction : dropped_session. transaction ,
591699 snapshot_time : dropped_session. snapshot_time ,
592700 operation_time : dropped_session. operation_time ,
701+ #[ cfg( test) ]
702+ convenient_transaction_timeout : None ,
593703 }
594704 }
595705}
0 commit comments