diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/mongodb/MongoSequenceIncrementer.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/mongodb/MongoSequenceIncrementer.java index 9722db637f..43ad9dfb54 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/mongodb/MongoSequenceIncrementer.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/mongodb/MongoSequenceIncrementer.java @@ -19,7 +19,11 @@ import com.mongodb.client.model.ReturnDocument; import org.bson.Document; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; import org.springframework.dao.DataAccessException; +import org.springframework.dao.DataIntegrityViolationException; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.jdbc.support.incrementer.DataFieldMaxValueIncrementer; @@ -29,10 +33,14 @@ /** * @author Mahmoud Ben Hassine * @author Christoph Strobl + * @author Yanming Zhou * @since 5.2.0 */ public class MongoSequenceIncrementer implements DataFieldMaxValueIncrementer { + private final RetryTemplate retryTemplate = new RetryTemplate( + RetryPolicy.builder().includes(DataIntegrityViolationException.class).build()); + private final MongoOperations mongoTemplate; private final String sequenceName; @@ -44,11 +52,22 @@ public MongoSequenceIncrementer(MongoOperations mongoTemplate, String sequenceNa @Override public long nextLongValue() throws DataAccessException { - return mongoTemplate.execute("BATCH_SEQUENCES", - collection -> collection + try { + return retryTemplate + .execute(() -> mongoTemplate.execute("BATCH_SEQUENCES", collection -> collection .findOneAndUpdate(new Document("_id", sequenceName), new Document("$inc", new Document("count", 1)), new FindOneAndUpdateOptions().returnDocument(ReturnDocument.AFTER)) - .getLong("count")); + .getLong("count"))); + } + catch (RetryException e) { + Throwable cause = e.getCause(); + if (cause instanceof DataAccessException ex) { + throw ex; + } + else { + throw new RuntimeException("Failed to retrieve next value of sequence", e); + } + } } @Override