Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 56 additions & 42 deletions core/src/main/scala/kafka/server/DelayedFetch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package kafka.server

import com.yammer.metrics.core.Meter
import io.aiven.inkless.control_plane.FindBatchRequest
import io.aiven.inkless.control_plane.FindBatchResponse
import kafka.utils.Logging

import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -60,6 +60,7 @@ class DelayedFetch(
minBytes: Option[Int] = None,
responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit,
) extends DelayedOperation(maxWaitMs.getOrElse(params.maxWaitMs)) with Logging {
var maybeBatchCoordinates: Option[Map[TopicIdPartition, FindBatchResponse]] = None

override def toString: String = {
s"DelayedFetch(params=$params" +
Expand Down Expand Up @@ -153,7 +154,15 @@ class DelayedFetch(
}
}

tryCompleteDiskless(disklessFetchPartitionStatus) match {
// adjust the max bytes for diskless fetches based on the percentage of diskless partitions
// Complete the classic fetches first
val classicRequestsSize = classicFetchPartitionStatus.size.toFloat
val disklessRequestsSize = disklessFetchPartitionStatus.size.toFloat
val totalRequestsSize = classicRequestsSize + disklessRequestsSize
val disklessPercentage = disklessRequestsSize / totalRequestsSize
val disklessParams = replicaManager.fetchParamsWithNewMaxBytes(params, disklessPercentage)

tryCompleteDiskless(disklessFetchPartitionStatus, disklessParams.maxBytes) match {
case Some(disklessAccumulatedSize) => accumulatedSize += disklessAccumulatedSize
case None => forceComplete()
}
Expand All @@ -174,53 +183,55 @@ class DelayedFetch(
* Case D: The fetch offset is equal to the end offset, meaning that we have reached the end of the log
* Upon completion, should return whatever data is available for each valid partition
*/
private def tryCompleteDiskless(fetchPartitionStatus: Seq[(TopicIdPartition, FetchPartitionStatus)]): Option[Long] = {
private def tryCompleteDiskless(
fetchPartitionStatus: Seq[(TopicIdPartition, FetchPartitionStatus)],
disklessMaxBytes: Int
): Option[Long] = {
var accumulatedSize = 0L
val fetchPartitionStatusMap = fetchPartitionStatus.toMap
val requests = fetchPartitionStatus.map { case (topicIdPartition, fetchStatus) =>
new FindBatchRequest(topicIdPartition, fetchStatus.startOffsetMetadata.messageOffset, fetchStatus.fetchInfo.maxBytes)
}
if (requests.isEmpty) return Some(0)

val response = try {
replicaManager.findDisklessBatches(requests, Int.MaxValue)
maybeBatchCoordinates = try {
Some(replicaManager.findDisklessBatches(fetchPartitionStatus, disklessMaxBytes))
} catch {
case e: Throwable =>
error("Error while trying to find diskless batches on delayed fetch.", e)
return None // Case C
}

response.get.asScala.foreach { r =>
r.errors() match {
case Errors.NONE =>
if (r.batches().size() > 0) {
// Gather topic id partition from first batch. Same for all batches in the response.
val topicIdPartition = r.batches().get(0).metadata().topicIdPartition()
val endOffset = r.highWatermark()

val fetchPartitionStatus = fetchPartitionStatusMap.get(topicIdPartition)
if (fetchPartitionStatus.isEmpty) {
warn(s"Fetch partition status for $topicIdPartition not found in delayed fetch $this.")
return None // Case C
}

val fetchOffset = fetchPartitionStatus.get.startOffsetMetadata
// If the fetch offset is greater than the end offset, it means that the log has been truncated
// If it is equal to the end offset, it means that we have reached the end of the log
// If the fetch offset is less than the end offset, we can accumulate the size of the batches
if (fetchOffset.messageOffset > endOffset) {
// Truncation happened
debug(s"Satisfying fetch $this since it is fetching later segments of partition $topicIdPartition.")
return None // Case A
} else if (fetchOffset.messageOffset < endOffset) {
val bytesAvailable = r.estimatedByteSize(fetchOffset.messageOffset)
accumulatedSize += bytesAvailable // Case B: accumulate the size of the batches
} // Case D: same as fetchOffset == endOffset, no new data available
maybeBatchCoordinates match {
case Some(exists) =>
exists.values.foreach { r =>
r.errors() match {
case Errors.NONE =>
if (r.batches().size() > 0) {
// Gather topic id partition from first batch. Same for all batches in the response.
val topicIdPartition = r.batches().get(0).metadata().topicIdPartition()
val endOffset = r.highWatermark()

val fetchPartitionStatus = fetchPartitionStatusMap.get(topicIdPartition)
if (fetchPartitionStatus.isEmpty) {
warn(s"Fetch partition status for $topicIdPartition not found in delayed fetch $this.")
return None // Case C
}

val fetchOffset = fetchPartitionStatus.get.startOffsetMetadata
// If the fetch offset is greater than the end offset, it means that the log has been truncated
// If it is equal to the end offset, it means that we have reached the end of the log
// If the fetch offset is less than the end offset, we can accumulate the size of the batches
if (fetchOffset.messageOffset > endOffset) {
// Truncation happened
debug(s"Satisfying fetch $this since it is fetching later segments of partition $topicIdPartition.")
return None // Case A
} else if (fetchOffset.messageOffset < endOffset) {
val bytesAvailable = r.estimatedByteSize(fetchOffset.messageOffset)
accumulatedSize += bytesAvailable // Case B: accumulate the size of the batches
} // Case D: same as fetchOffset == endOffset, no new data available
}
case _ => return None // Case C
}
case _ => return None // Case C
}
}
case None => // Case D
}

Some(accumulatedSize)
}

Expand Down Expand Up @@ -272,13 +283,16 @@ class DelayedFetch(

if (disklessRequestsSize > 0) {
// Classic fetches are complete, now handle diskless fetches
// adjust the max bytes for diskless fetches based on the percentage of diskless partitions
val disklessPercentage = disklessRequestsSize / totalRequestsSize
val disklessParams = replicaManager.fetchParamsWithNewMaxBytes(params, disklessPercentage)
val disklessFetchInfos = disklessFetchPartitionStatus.map { case (tp, status) =>
tp -> status.fetchInfo
}
val disklessFetchResponseFuture = replicaManager.fetchDisklessMessages(disklessParams, disklessFetchInfos)
val batchCoordinates = maybeBatchCoordinates match {
case Some(batchCoordinates) => batchCoordinates
case None =>
responseCallback(Seq.empty)
return
}
val disklessFetchResponseFuture = replicaManager.fetchDisklessMessages(batchCoordinates, disklessFetchInfos)

// Combine the classic fetch results with the diskless fetch results
disklessFetchResponseFuture.whenComplete { case (disklessFetchPartitionData, _) =>
Expand Down
52 changes: 45 additions & 7 deletions core/src/main/scala/kafka/server/ReplicaManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1719,16 +1719,40 @@ class ReplicaManager(val config: KafkaConfig,
}
}

def findDisklessBatches(requests: Seq[FindBatchRequest], maxBytes: Int): Option[util.List[FindBatchResponse]] = {
inklessSharedState.map { sharedState =>
sharedState.controlPlane().findBatches(requests.asJava, maxBytes, sharedState.config().maxBatchesPerPartitionToFind())
def findDisklessBatches(fetchPartitionStatus: Seq[(TopicIdPartition, FetchPartitionStatus)], maxBytes: Int): Map[TopicIdPartition, FindBatchResponse] = {
val requests = fetchPartitionStatus.map { case (topicIdPartition, fetchStatus) =>
new FindBatchRequest(topicIdPartition, fetchStatus.startOffsetMetadata.messageOffset, fetchStatus.fetchInfo.maxBytes)
}
if (requests.isEmpty) return Map.empty

val findBatchResponses = try {
inklessSharedState.map { sharedState =>
sharedState.controlPlane().findBatches(requests.asJava, maxBytes, sharedState.config().maxBatchesPerPartitionToFind())
}
} match {
case Some(responses) => responses
case None =>
return Map.empty
} catch {
case e: Throwable =>
// kala
trace("Error while trying to find diskless batches.", e)
return Map.empty
}

val topicPartitionToFindBatchResponse = collection.mutable.Map[TopicIdPartition, FindBatchResponse]()
for (i <- requests.indices) {
val request = requests(i)
val response = findBatchResponses.get(i)
topicPartitionToFindBatchResponse.update(request.topicIdPartition, response)
}
topicPartitionToFindBatchResponse;
}

def fetchDisklessMessages(params: FetchParams,
def fetchDisklessMessages(batchCoordinates: Map[TopicIdPartition, FindBatchResponse],
fetchInfos: Seq[(TopicIdPartition, PartitionData)]): CompletableFuture[Seq[(TopicIdPartition, FetchPartitionData)]] = {
inklessFetchHandler match {
case Some(handler) => handler.handle(params, fetchInfos.toMap.asJava).thenApply(_.asScala.toSeq)
case Some(handler) => handler.handle(batchCoordinates.asJava, fetchInfos.toMap.asJava).thenApply(_.asScala.toSeq)
case None =>
if (fetchInfos.nonEmpty)
error(s"Received diskless fetch request for topics ${fetchInfos.map(_._1.topic()).distinct.mkString(", ")} but diskless fetch handler is not available. " +
Expand Down Expand Up @@ -1830,6 +1854,8 @@ class ReplicaManager(val config: KafkaConfig,
delayedFetchPurgatory.tryCompleteElseWatch(delayedFetch, (classicDelayedFetchKeys ++ disklessDelayedFetchKeys).asJava)
}

// If there is nothing to fetch for classic topics,
// create delayed response and fetch possible diskless data there.
if (classicFetchInfos.isEmpty) {
delayedResponse(Seq.empty)
return
Expand Down Expand Up @@ -1894,9 +1920,18 @@ class ReplicaManager(val config: KafkaConfig,
// In case of remote fetches, synchronously wait for diskless records and then perform the remote fetch.
// This is currently a workaround to avoid modifying the DelayedRemoteFetch in order to correctly process
// diskless fetches.
// Get diskless batch coordinates and hand over to fetching
val batchCoordinates = try {
findDisklessBatches(fetchPartitionStatus, Int.MaxValue)
} catch {
case e: Throwable =>
error("Error while trying to find diskless batches on remote fetch.", e)
responseCallback(Seq.empty)
return
}

val disklessFetchResults = try {
val disklessParams = fetchParamsWithNewMaxBytes(params, disklessFetchInfos.size.toFloat / fetchInfos.size.toFloat)
val disklessResponsesFuture = fetchDisklessMessages(disklessParams, disklessFetchInfos)
val disklessResponsesFuture = fetchDisklessMessages(batchCoordinates, disklessFetchInfos)

val response = disklessResponsesFuture.get(maxWaitMs, TimeUnit.MILLISECONDS)
response.map { case (tp, data) =>
Expand Down Expand Up @@ -1933,8 +1968,11 @@ class ReplicaManager(val config: KafkaConfig,
}
} else {
if (disklessFetchInfos.isEmpty && (bytesReadable >= params.minBytes || params.maxWaitMs <= 0)) {
// No remote fetch needed and not any diskless topics to be fetched.
// Response immediately.
responseCallback(fetchPartitionData)
} else {
// No remote fetch, requires fetching data from the diskless topics.
delayedResponse(fetchPartitionStatus)
}
}
Expand Down
36 changes: 21 additions & 15 deletions core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package kafka.server

import io.aiven.inkless.control_plane.{BatchInfo, BatchMetadata, FindBatchRequest, FindBatchResponse}
import io.aiven.inkless.control_plane.{BatchInfo, BatchMetadata, FindBatchResponse}

import java.util.{Collections, Optional, OptionalLong}
import scala.collection.Seq
Expand Down Expand Up @@ -213,6 +213,9 @@ class DelayedFetchTest {
responseCallback = callback
)

val batchCoordinates = Map.empty[TopicIdPartition, FindBatchResponse]
when(replicaManager.findDisklessBatches(any(), anyInt())).thenReturn(batchCoordinates)

val partition: Partition = mock(classOf[Partition])
when(replicaManager.getPartitionOrException(topicIdPartition.topicPartition)).thenReturn(partition)
// Note that the high-watermark does not contain the complete metadata
Expand Down Expand Up @@ -345,12 +348,13 @@ class DelayedFetchTest {
)))
when(mockResponse.highWatermark()).thenReturn(endOffset) // endOffset < fetchOffset (truncation)

val future = Some(Collections.singletonList(mockResponse))
when(replicaManager.findDisklessBatches(any[Seq[FindBatchRequest]], anyInt())).thenReturn(future)
val batchCoordinates = Map((topicIdPartition, mockResponse))

when(replicaManager.findDisklessBatches(any(), anyInt())).thenReturn(batchCoordinates)

// Mock fetchDisklessMessages for onComplete
when(replicaManager.fetchParamsWithNewMaxBytes(any[FetchParams], any[Float])).thenAnswer(_.getArgument(0))
when(replicaManager.fetchDisklessMessages(any[FetchParams], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
when(replicaManager.fetchDisklessMessages(any[Map[TopicIdPartition, FindBatchResponse]], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
.thenReturn(CompletableFuture.completedFuture(Seq((topicIdPartition, mock(classOf[FetchPartitionData])))))

when(replicaManager.readFromLog(
Expand Down Expand Up @@ -402,6 +406,7 @@ class DelayedFetchTest {
fetchResultOpt = Some(responses)
}

when(replicaManager.fetchParamsWithNewMaxBytes(any(), any())).thenReturn(fetchParams)
val delayedFetch = new DelayedFetch(
params = fetchParams,
classicFetchPartitionStatus = Seq.empty,
Expand Down Expand Up @@ -434,8 +439,8 @@ class DelayedFetchTest {
when(mockResponse.highWatermark()).thenReturn(fetchOffset) // fetchOffset == endOffset (no new data)
when(mockResponse.estimatedByteSize(fetchOffset)).thenReturn(estimatedBatchSize)

val future = Some(Collections.singletonList(mockResponse))
when(replicaManager.findDisklessBatches(any[Seq[FindBatchRequest]], anyInt())).thenReturn(future)
val future = Map((topicIdPartition, mockResponse))
when(replicaManager.findDisklessBatches(any[Seq[(TopicIdPartition, FetchPartitionStatus)]], anyInt())).thenReturn(future)

when(replicaManager.readFromLog(
fetchParams,
Expand All @@ -451,7 +456,7 @@ class DelayedFetchTest {
assertFalse(fetchResultOpt.isDefined)

// Verify that estimatedByteSize is never called since fetchOffset == endOffset
verify(replicaManager, never()).fetchDisklessMessages(any[FetchParams], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]])
verify(replicaManager, never()).fetchDisklessMessages(any[Map[TopicIdPartition, FindBatchResponse]], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]])
verify(mockResponse, never()).estimatedByteSize(anyLong())
}

Expand Down Expand Up @@ -487,6 +492,7 @@ class DelayedFetchTest {
fetchResultOpt = Some(responses)
}

when(replicaManager.fetchParamsWithNewMaxBytes(any(), any())).thenReturn(fetchParams)
val delayedFetch = new DelayedFetch(
params = fetchParams,
classicFetchPartitionStatus = Seq.empty,
Expand Down Expand Up @@ -519,8 +525,8 @@ class DelayedFetchTest {
when(mockResponse.highWatermark()).thenReturn(endOffset) // endOffset > fetchOffset (data available)
when(mockResponse.estimatedByteSize(fetchOffset)).thenReturn(estimatedBatchSize)

val future = Some(Collections.singletonList(mockResponse))
when(replicaManager.findDisklessBatches(any[Seq[FindBatchRequest]], anyInt())).thenReturn(future)
val batchCoordinates = Map((topicIdPartition, mockResponse))
when(replicaManager.findDisklessBatches(any(), anyInt())).thenReturn(batchCoordinates)

when(replicaManager.readFromLog(
fetchParams,
Expand Down Expand Up @@ -601,12 +607,12 @@ class DelayedFetchTest {
when(mockResponse.highWatermark()).thenReturn(endOffset) // endOffset > fetchOffset (data available)
when(mockResponse.estimatedByteSize(fetchOffset)).thenReturn(estimatedBatchSize)

val future = Some(Collections.singletonList(mockResponse))
when(replicaManager.findDisklessBatches(any[Seq[FindBatchRequest]], anyInt())).thenReturn(future)
val batchCoordinates = Map((topicIdPartition, mockResponse))
when(replicaManager.findDisklessBatches(any(), anyInt())).thenReturn(batchCoordinates)

// Mock fetchDisklessMessages for onComplete
when(replicaManager.fetchParamsWithNewMaxBytes(any[FetchParams], anyFloat())).thenAnswer(_.getArgument(0))
when(replicaManager.fetchDisklessMessages(any[FetchParams], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
when(replicaManager.fetchDisklessMessages(any[Map[TopicIdPartition, FindBatchResponse]], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
.thenReturn(CompletableFuture.completedFuture(Seq((topicIdPartition, mock(classOf[FetchPartitionData])))))

when(replicaManager.readFromLog(
Expand Down Expand Up @@ -685,12 +691,12 @@ class DelayedFetchTest {
)))
when(mockResponse.highWatermark()).thenReturn(600L)

val future = Some(Collections.singletonList(mockResponse))
when(replicaManager.findDisklessBatches(any[Seq[FindBatchRequest]], anyInt())).thenReturn(future)
val batchCoordinates = Map((topicIdPartition, mockResponse))
when(replicaManager.findDisklessBatches(any(), anyInt())).thenReturn(batchCoordinates)

// Mock fetchDisklessMessages for onComplete
when(replicaManager.fetchParamsWithNewMaxBytes(any[FetchParams], anyFloat())).thenAnswer(_.getArgument(0))
when(replicaManager.fetchDisklessMessages(any[FetchParams], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
when(replicaManager.fetchDisklessMessages(any[Map[TopicIdPartition, FindBatchResponse]], any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]]))
.thenReturn(CompletableFuture.completedFuture(Seq((topicIdPartition, mock(classOf[FetchPartitionData])))))

when(replicaManager.readFromLog(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7098,9 +7098,10 @@ class ReplicaManagerTest {
// and the response does not satisfy minBytes, it should be delayed in the purgatory
// until the delayed fetch expires.
replicaManager.fetchMessages(fetchParams, fetchInfos, QuotaFactory.UNBOUNDED_QUOTA, responseCallback)
assertEquals(0, replicaManager.delayedFetchPurgatory.numDelayed())
assertEquals(1, replicaManager.delayedFetchPurgatory.numDelayed())

latch.await(10, TimeUnit.SECONDS) // Wait for the delayed fetch to expire
assertEquals(0, replicaManager.delayedFetchPurgatory.numDelayed())
assertNotNull(responseData)
assertEquals(2, responseData.size)
assertEquals(disklessResponse(disklessTopicPartition), responseData(disklessTopicPartition))
Expand Down
Loading
Loading