Skip to content

Commit b6389e2

Browse files
mustansir14amanfcp
andauthored
[INS-104] Support units in S3 source (#4560)
* implemented Source unit for S3. Implemented. integration test * use bucket as source unit * remove code duplication, reuse from Chunks * remove unnecessary change * remove unused functions * revisit tests * revert unnecessary change * change SourceUnitKind to s3_bucket * handle nil objectCount inside scanBucket * handle nil objectCount outside loop * add bucket to resume log * add bucket and role to error log, remove enumerating log * implement sub unit resumption * add comment to checkpointer for unit scans * implement SourceUnitUnmarshaller on source with the new S3SourceUnit, add test to test resumption on multiple buckets with concurrent ChunkUnit processing * add role to SourceUnitID * Revert "add role to SourceUnitID" This reverts commit 549e6be. * add role to source unit ID, keep track of resumption using source unit ID instead of just bucket name * rename bucket -> unitID in UnmarshalSourceUnit --------- Co-authored-by: Amaan Ullah <aman.ullah.jalal@trufflesec.com>
1 parent 1935692 commit b6389e2

File tree

8 files changed

+542
-72
lines changed

8 files changed

+542
-72
lines changed

pkg/sources/s3/checkpointer.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
1212
)
1313

14+
// TODO [INS-207] Add role to legacy scan resumption info
15+
1416
// Checkpointer maintains resumption state for S3 bucket scanning,
1517
// enabling resumable scans by tracking which objects have been successfully processed.
1618
// It provides checkpoints that can be used to resume interrupted scans without missing objects.
@@ -33,6 +35,10 @@ import (
3335
// resuming from the correct bucket. The scan will continue from the last checkpointed object
3436
// in that bucket.
3537
//
38+
// Unit scans are also supported. The encoded resume info in this case tracks the last processed object
39+
// for each unit separately by using the SetEncodedResumeInfoFor method on Progress. To use the
40+
// checkpointer for unit scans, call SetIsUnitScan(true) before starting the scan.
41+
//
3642
// For example, if scanning is interrupted after processing 1500 objects across 2 pages:
3743
// Page 1 (objects 0-999): Fully processed, checkpoint saved at object 999
3844
// Page 2 (objects 1000-1999): Partially processed through 1600, but only consecutive through 1499
@@ -56,6 +62,8 @@ type Checkpointer struct {
5662
// progress holds the scan's overall progress state and enables persistence.
5763
// The EncodedResumeInfo field stores the JSON-encoded ResumeInfo checkpoint.
5864
progress *sources.Progress // Reference to source's Progress
65+
66+
isUnitScan bool // Indicates if scanning is done in unit scan mode
5967
}
6068

6169
const defaultMaxObjectsPerPage = 1000
@@ -153,9 +161,10 @@ func (p *Checkpointer) UpdateObjectCompletion(
153161
ctx context.Context,
154162
completedIdx int,
155163
bucket string,
164+
role string,
156165
pageContents []s3types.Object,
157166
) error {
158-
ctx = context.WithValues(ctx, "bucket", bucket, "completedIdx", completedIdx)
167+
ctx = context.WithValues(ctx, "bucket", bucket, "role", role, "completedIdx", completedIdx)
159168
ctx.Logger().V(5).Info("Updating progress")
160169

161170
if completedIdx >= len(p.completedObjects) {
@@ -184,7 +193,7 @@ func (p *Checkpointer) UpdateObjectCompletion(
184193
}
185194
obj := pageContents[checkpointIdx]
186195

187-
return p.updateCheckpoint(bucket, *obj.Key)
196+
return p.updateCheckpoint(bucket, role, *obj.Key)
188197
}
189198

190199
// advanceLowestIncompleteIdx moves the lowest incomplete index forward to the next incomplete object.
@@ -198,7 +207,14 @@ func (p *Checkpointer) advanceLowestIncompleteIdx() {
198207

199208
// updateCheckpoint persists the current resumption state.
200209
// Must be called with lock held.
201-
func (p *Checkpointer) updateCheckpoint(bucket string, lastKey string) error {
210+
func (p *Checkpointer) updateCheckpoint(bucket string, role string, lastKey string) error {
211+
if p.isUnitScan {
212+
unitID := constructS3SourceUnitID(bucket, role)
213+
// track sub-unit resumption state
214+
p.progress.SetEncodedResumeInfoFor(unitID, lastKey)
215+
return nil
216+
}
217+
202218
encoded, err := json.Marshal(&ResumeInfo{CurrentBucket: bucket, StartAfter: lastKey})
203219
if err != nil {
204220
return fmt.Errorf("failed to encode resume info: %w", err)
@@ -212,3 +228,11 @@ func (p *Checkpointer) updateCheckpoint(bucket string, lastKey string) error {
212228
)
213229
return nil
214230
}
231+
232+
// SetIsUnitScan sets whether the checkpointer is operating in unit scan mode.
233+
func (p *Checkpointer) SetIsUnitScan(isUnitScan bool) {
234+
p.mu.Lock()
235+
defer p.mu.Unlock()
236+
237+
p.isUnitScan = isUnitScan
238+
}

pkg/sources/s3/checkpointer_test.go

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func TestCheckpointerResumption(t *testing.T) {
3131

3232
// Process first 6 objects.
3333
for i := range 6 {
34-
err := tracker.UpdateObjectCompletion(ctx, i, "test-bucket", firstPage.Contents)
34+
err := tracker.UpdateObjectCompletion(ctx, i, "test-bucket", "", firstPage.Contents)
3535
assert.NoError(t, err)
3636
}
3737

@@ -50,7 +50,7 @@ func TestCheckpointerResumption(t *testing.T) {
5050

5151
// Process remaining objects.
5252
for i := range len(resumePage.Contents) {
53-
err := resumeTracker.UpdateObjectCompletion(ctx, i, "test-bucket", resumePage.Contents)
53+
err := resumeTracker.UpdateObjectCompletion(ctx, i, "test-bucket", "", resumePage.Contents)
5454
assert.NoError(t, err)
5555
}
5656

@@ -244,7 +244,7 @@ func TestCheckpointerUpdate(t *testing.T) {
244244
}
245245
}
246246

247-
err := tracker.UpdateObjectCompletion(ctx, tt.completedIdx, "test-bucket", page.Contents)
247+
err := tracker.UpdateObjectCompletion(ctx, tt.completedIdx, "test-bucket", "", page.Contents)
248248
assert.NoError(t, err, "Unexpected error updating progress")
249249

250250
var info ResumeInfo
@@ -258,6 +258,36 @@ func TestCheckpointerUpdate(t *testing.T) {
258258
}
259259
}
260260

261+
func TestCheckpointerUpdateUnitScan(t *testing.T) {
262+
ctx := context.Background()
263+
progress := new(sources.Progress)
264+
tracker := NewCheckpointer(ctx, progress)
265+
tracker.SetIsUnitScan(true)
266+
267+
page := &s3.ListObjectsV2Output{
268+
Contents: make([]s3types.Object, 3),
269+
}
270+
for i := range 3 {
271+
key := fmt.Sprintf("key-%d", i)
272+
page.Contents[i] = s3types.Object{Key: &key}
273+
}
274+
275+
// Complete first object.
276+
err := tracker.UpdateObjectCompletion(ctx, 0, "test-bucket", "test-role", page.Contents)
277+
assert.NoError(t, err, "Unexpected error updating progress")
278+
279+
var info map[string]string
280+
err = json.Unmarshal([]byte(progress.EncodedResumeInfo), &info)
281+
var gotUnitID, gotStartAfter string
282+
for k, v := range info {
283+
gotUnitID = k
284+
gotStartAfter = v
285+
}
286+
assert.NoError(t, err, "Failed to decode resume info")
287+
assert.Equal(t, "test-role|test-bucket", gotUnitID, "Incorrect unit ID")
288+
assert.Equal(t, "key-0", gotStartAfter, "Incorrect resume point")
289+
}
290+
261291
func TestComplete(t *testing.T) {
262292
tests := []struct {
263293
name string

0 commit comments

Comments
 (0)