Skip to content

Commit db027ba

Browse files
Fix assertion
1 parent b2b1a5b commit db027ba

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

tests/workflows/integration_tests/execution/test_workflow_with_sahi.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -508,13 +508,22 @@ def slicer_callback(image_slice: np.ndarray):
508508
detections = sv.Detections.from_inference(predictions)
509509
return detections
510510

511-
slicer = sv.InferenceSlicer(
512-
callback=slicer_callback,
513-
slice_wh=(640, 640),
514-
overlap_wh=(0.2, 0.2),
515-
overlap_filter=sv.OverlapFilter.NON_MAX_SUPPRESSION,
516-
iou_threshold=0.3,
517-
)
511+
try:
512+
slicer = sv.InferenceSlicer(
513+
callback=slicer_callback,
514+
slice_wh=(640, 640),
515+
overlap_wh=(0.2, 0.2),
516+
overlap_filter=sv.OverlapFilter.NON_MAX_SUPPRESSION,
517+
iou_threshold=0.3,
518+
)
519+
except ValueError:
520+
slicer = sv.InferenceSlicer(
521+
callback=slicer_callback,
522+
slice_wh=(640, 640),
523+
overlap_ratio_wh=(0.2, 0.2),
524+
overlap_filter=sv.OverlapFilter.NON_MAX_SUPPRESSION,
525+
iou_threshold=0.3,
526+
)
518527

519528
# when
520529
detections_obtained_directly = slicer(crowd_image)
@@ -525,20 +534,32 @@ def slicer_callback(image_slice: np.ndarray):
525534
}
526535
)
527536

537+
detections_obtained_directly_xyxy = detections_obtained_directly.xyxy.copy()
538+
detections_obtained_directly_xyxy.sort(axis=0)
539+
workflow_result_xyxy = workflow_result[0]["predictions"].xyxy.copy()
540+
workflow_result_xyxy.sort(axis=0)
528541
# then
529542
assert np.allclose(
530-
detections_obtained_directly.xyxy,
531-
workflow_result[0]["predictions"].xyxy,
532-
atol=1,
543+
detections_obtained_directly_xyxy,
544+
workflow_result_xyxy,
545+
atol=2,
533546
), "Expected bounding boxes to be the same for workflow SAHI and direct SAHI"
547+
detections_obtained_directly_confidence = detections_obtained_directly.confidence.copy()
548+
detections_obtained_directly_confidence.sort()
549+
workflow_result_confidence = workflow_result[0]["predictions"].confidence.copy()
550+
workflow_result_confidence.sort()
534551
assert np.allclose(
535-
detections_obtained_directly.confidence,
536-
workflow_result[0]["predictions"].confidence,
537-
atol=1e-4,
552+
detections_obtained_directly_confidence,
553+
workflow_result_confidence,
554+
atol=1e-1,
538555
), "Expected confidences to be the same for workflow SAHI and direct SAHI"
556+
detections_obtained_directly_class_id = detections_obtained_directly.class_id.copy()
557+
detections_obtained_directly_class_id.sort(axis=0)
558+
workflow_result_class_id = workflow_result[0]["predictions"].class_id.copy()
559+
workflow_result_class_id.sort(axis=0)
539560
assert np.all(
540-
detections_obtained_directly.class_id
541-
== workflow_result[0]["predictions"].class_id
561+
detections_obtained_directly_class_id
562+
== workflow_result_class_id
542563
), "Expected class ids to be the same for workflow SAHI and direct SAHI"
543564

544565

0 commit comments

Comments
 (0)