diff --git a/.gitignore b/.gitignore index 96ee987..36e3fa3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,10 @@ .idea/ build/ pose_evaluation.egg-info/ -**/__pycache__/ \ No newline at end of file +**/__pycache__/ +.coverage +.vscode/ +coverage.lcov +**/test_data/ +*.npz +*.code-workspace diff --git a/pose_evaluation/metrics/.gitignore b/pose_evaluation/metrics/.gitignore index cd78447..2b29f27 100644 --- a/pose_evaluation/metrics/.gitignore +++ b/pose_evaluation/metrics/.gitignore @@ -1 +1 @@ -temp/ \ No newline at end of file +tests diff --git a/pose_evaluation/utils/conftest.py b/pose_evaluation/utils/conftest.py new file mode 100644 index 0000000..2d3031b --- /dev/null +++ b/pose_evaluation/utils/conftest.py @@ -0,0 +1,51 @@ +import json +import copy +from pathlib import Path +from typing import List, Dict + +import pytest +from pose_format import Pose +from pose_format.utils.generic import fake_pose +from pose_format.utils.openpose_135 import ( + OpenPose_Components as openpose_135_components, +) + +from pose_evaluation.utils.pose_utils import load_pose_file + + + +utils_test_data_dir = Path(__file__).parent / "test" / "test_data" + + +@pytest.fixture(scope="function") +def mediapipe_poses_test_data_paths() -> List[Path]: + pose_file_paths = list(utils_test_data_dir.glob("*.pose")) + return pose_file_paths + + +@pytest.fixture(scope="function") +def mediapipe_poses_test_data(mediapipe_poses_test_data_paths) -> List[Pose]: + original_poses = [ + load_pose_file(pose_path) for pose_path in mediapipe_poses_test_data_paths + ] + # I ran into issues where if one test would modify a Pose, it would affect other tests. + # specifically, pose.header.components[0].name = unsupported_component_name in test_detect_format + # this ensures we get a fresh object each time. + return copy.deepcopy(original_poses) + + +@pytest.fixture +def standard_mediapipe_components_dict() -> Dict[str, List[str]]: + format_json = utils_test_data_dir / "mediapipe_components_and_points.json" + with open(format_json, "r", encoding="utf-8") as f: + return json.load(f) + + +@pytest.fixture +def fake_openpose_poses(count: int = 3) -> List[Pose]: + return [fake_pose(30) for _ in range(count)] + + +@pytest.fixture +def fake_openpose_135_poses(count: int = 3) -> List[Pose]: + return [fake_pose(30, components=openpose_135_components) for _ in range(count)] diff --git a/pose_evaluation/utils/pose_utils.py b/pose_evaluation/utils/pose_utils.py new file mode 100644 index 0000000..cfc8147 --- /dev/null +++ b/pose_evaluation/utils/pose_utils.py @@ -0,0 +1,92 @@ +from pathlib import Path +from typing import List, Tuple, Dict, Iterable +from collections import defaultdict +import numpy as np +from numpy import ma +from pose_format import Pose + + +def pose_remove_world_landmarks(pose: Pose) -> Pose: + return pose.remove_components(["POSE_WORLD_LANDMARKS"]) + + +def get_component_names_and_points_dict( + pose: Pose, +) -> Tuple[List[str], Dict[str, List[str]]]: + component_names = [] + points_dict = defaultdict(list) + for component in pose.header.components: + component_names.append(component.name) + + for point in component.points: + points_dict[component.name].append(point) + + return component_names, points_dict + + +def get_face_and_hands_from_pose(pose: Pose) -> Pose: + # based on MediaPipe Holistic format. + components_to_keep = [ + "FACE_LANDMARKS", + "LEFT_HAND_LANDMARKS", + "RIGHT_HAND_LANDMARKS", + ] + return pose.get_components(components_to_keep) + + +def load_pose_file(pose_path: Path) -> Pose: + pose_path = Path(pose_path).resolve() + with pose_path.open("rb") as f: + pose = Pose.read(f.read()) + return pose + + +def reduce_poses_to_intersection( + poses: Iterable[Pose], +) -> List[Pose]: + poses = list(poses) # get a list, no need to copy + + # look at the first pose + component_names = {c.name for c in poses[0].header.components} + points = {c.name: set(c.points) for c in poses[0].header.components} + + # remove anything that other poses don't have + for pose in poses[1:]: + component_names.intersection_update({c.name for c in pose.header.components}) + for component in pose.header.components: + points[component.name].intersection_update(set(component.points)) + + # change datatypes to match get_components, then update the poses + points_dict = {} + for c_name in points.keys(): + points_dict[c_name] = list(points[c_name]) + poses = [pose.get_components(list(component_names), points_dict) for pose in poses] + return poses + + +def zero_pad_shorter_poses(poses: Iterable[Pose]) -> List[Pose]: + poses = [pose.copy() for pose in poses] + # arrays = [pose.body.data for pose in poses] + + # first dimension is frames. Then People, joint-points, XYZ or XY + max_frame_count = max(len(pose.body.data) for pose in poses) + # Pad the shorter array with zeros + for pose in poses: + if len(pose.body.data) < max_frame_count: + desired_shape = list(pose.body.data.shape) + desired_shape[0] = max_frame_count - len(pose.body.data) + padding_tensor = ma.zeros(desired_shape) + padding_tensor_conf = ma.ones(desired_shape[:-1]) + pose.body.data = ma.concatenate([pose.body.data, padding_tensor], axis=0) + pose.body.confidence = ma.concatenate( + [pose.body.confidence, padding_tensor_conf] + ) + return poses + + +def pose_hide_low_conf(pose: Pose, confidence_threshold: float = 0.2) -> None: + mask = pose.body.confidence <= confidence_threshold + pose.body.confidence[mask] = 0 + stacked_confidence = np.stack([mask, mask, mask], axis=3) + masked_data = ma.masked_array(pose.body.data, mask=stacked_confidence) + pose.body.data = masked_data diff --git a/pose_evaluation/utils/test/test_data/colin-1-HOUSE.pose b/pose_evaluation/utils/test/test_data/colin-1-HOUSE.pose new file mode 100644 index 0000000..8f74698 Binary files /dev/null and b/pose_evaluation/utils/test/test_data/colin-1-HOUSE.pose differ diff --git a/pose_evaluation/utils/test/test_data/colin-2-HOUSE.pose b/pose_evaluation/utils/test/test_data/colin-2-HOUSE.pose new file mode 100644 index 0000000..67d12e3 Binary files /dev/null and b/pose_evaluation/utils/test/test_data/colin-2-HOUSE.pose differ diff --git a/pose_evaluation/utils/test/test_data/colin-SAD.pose b/pose_evaluation/utils/test/test_data/colin-SAD.pose new file mode 100644 index 0000000..f00fc7f Binary files /dev/null and b/pose_evaluation/utils/test/test_data/colin-SAD.pose differ diff --git a/pose_evaluation/utils/test/test_data/mediapipe_components_and_points.json b/pose_evaluation/utils/test/test_data/mediapipe_components_and_points.json new file mode 100644 index 0000000..d42dddc --- /dev/null +++ b/pose_evaluation/utils/test/test_data/mediapipe_components_and_points.json @@ -0,0 +1,588 @@ +{ + "POSE_LANDMARKS": [ + "NOSE", + "LEFT_EYE_INNER", + "LEFT_EYE", + "LEFT_EYE_OUTER", + "RIGHT_EYE_INNER", + "RIGHT_EYE", + "RIGHT_EYE_OUTER", + "LEFT_EAR", + "RIGHT_EAR", + "MOUTH_LEFT", + "MOUTH_RIGHT", + "LEFT_SHOULDER", + "RIGHT_SHOULDER", + "LEFT_ELBOW", + "RIGHT_ELBOW", + "LEFT_WRIST", + "RIGHT_WRIST", + "LEFT_PINKY", + "RIGHT_PINKY", + "LEFT_INDEX", + "RIGHT_INDEX", + "LEFT_THUMB", + "RIGHT_THUMB", + "LEFT_HIP", + "RIGHT_HIP", + "LEFT_KNEE", + "RIGHT_KNEE", + "LEFT_ANKLE", + "RIGHT_ANKLE", + "LEFT_HEEL", + "RIGHT_HEEL", + "LEFT_FOOT_INDEX", + "RIGHT_FOOT_INDEX" + ], + "FACE_LANDMARKS": [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + "12", + "13", + "14", + "15", + "16", + "17", + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + "26", + "27", + "28", + "29", + "30", + "31", + "32", + "33", + "34", + "35", + "36", + "37", + "38", + "39", + "40", + "41", + "42", + "43", + "44", + "45", + "46", + "47", + "48", + "49", + "50", + "51", + "52", + "53", + "54", + "55", + "56", + "57", + "58", + "59", + "60", + "61", + "62", + "63", + "64", + "65", + "66", + "67", + "68", + "69", + "70", + "71", + "72", + "73", + "74", + "75", + "76", + "77", + "78", + "79", + "80", + "81", + "82", + "83", + "84", + "85", + "86", + "87", + "88", + "89", + "90", + "91", + "92", + "93", + "94", + "95", + "96", + "97", + "98", + "99", + "100", + "101", + "102", + "103", + "104", + "105", + "106", + "107", + "108", + "109", + "110", + "111", + "112", + "113", + "114", + "115", + "116", + "117", + "118", + "119", + "120", + "121", + "122", + "123", + "124", + "125", + "126", + "127", + "128", + "129", + "130", + "131", + "132", + "133", + "134", + "135", + "136", + "137", + "138", + "139", + "140", + "141", + "142", + "143", + "144", + "145", + "146", + "147", + "148", + "149", + "150", + "151", + "152", + "153", + "154", + "155", + "156", + "157", + "158", + "159", + "160", + "161", + "162", + "163", + "164", + "165", + "166", + "167", + "168", + "169", + "170", + "171", + "172", + "173", + "174", + "175", + "176", + "177", + "178", + "179", + "180", + "181", + "182", + "183", + "184", + "185", + "186", + "187", + "188", + "189", + "190", + "191", + "192", + "193", + "194", + "195", + "196", + "197", + "198", + "199", + "200", + "201", + "202", + "203", + "204", + "205", + "206", + "207", + "208", + "209", + "210", + "211", + "212", + "213", + "214", + "215", + "216", + "217", + "218", + "219", + "220", + "221", + "222", + "223", + "224", + "225", + "226", + "227", + "228", + "229", + "230", + "231", + "232", + "233", + "234", + "235", + "236", + "237", + "238", + "239", + "240", + "241", + "242", + "243", + "244", + "245", + "246", + "247", + "248", + "249", + "250", + "251", + "252", + "253", + "254", + "255", + "256", + "257", + "258", + "259", + "260", + "261", + "262", + "263", + "264", + "265", + "266", + "267", + "268", + "269", + "270", + "271", + "272", + "273", + "274", + "275", + "276", + "277", + "278", + "279", + "280", + "281", + "282", + "283", + "284", + "285", + "286", + "287", + "288", + "289", + "290", + "291", + "292", + "293", + "294", + "295", + "296", + "297", + "298", + "299", + "300", + "301", + "302", + "303", + "304", + "305", + "306", + "307", + "308", + "309", + "310", + "311", + "312", + "313", + "314", + "315", + "316", + "317", + "318", + "319", + "320", + "321", + "322", + "323", + "324", + "325", + "326", + "327", + "328", + "329", + "330", + "331", + "332", + "333", + "334", + "335", + "336", + "337", + "338", + "339", + "340", + "341", + "342", + "343", + "344", + "345", + "346", + "347", + "348", + "349", + "350", + "351", + "352", + "353", + "354", + "355", + "356", + "357", + "358", + "359", + "360", + "361", + "362", + "363", + "364", + "365", + "366", + "367", + "368", + "369", + "370", + "371", + "372", + "373", + "374", + "375", + "376", + "377", + "378", + "379", + "380", + "381", + "382", + "383", + "384", + "385", + "386", + "387", + "388", + "389", + "390", + "391", + "392", + "393", + "394", + "395", + "396", + "397", + "398", + "399", + "400", + "401", + "402", + "403", + "404", + "405", + "406", + "407", + "408", + "409", + "410", + "411", + "412", + "413", + "414", + "415", + "416", + "417", + "418", + "419", + "420", + "421", + "422", + "423", + "424", + "425", + "426", + "427", + "428", + "429", + "430", + "431", + "432", + "433", + "434", + "435", + "436", + "437", + "438", + "439", + "440", + "441", + "442", + "443", + "444", + "445", + "446", + "447", + "448", + "449", + "450", + "451", + "452", + "453", + "454", + "455", + "456", + "457", + "458", + "459", + "460", + "461", + "462", + "463", + "464", + "465", + "466", + "467" + ], + "LEFT_HAND_LANDMARKS": [ + "WRIST", + "THUMB_CMC", + "THUMB_MCP", + "THUMB_IP", + "THUMB_TIP", + "INDEX_FINGER_MCP", + "INDEX_FINGER_PIP", + "INDEX_FINGER_DIP", + "INDEX_FINGER_TIP", + "MIDDLE_FINGER_MCP", + "MIDDLE_FINGER_PIP", + "MIDDLE_FINGER_DIP", + "MIDDLE_FINGER_TIP", + "RING_FINGER_MCP", + "RING_FINGER_PIP", + "RING_FINGER_DIP", + "RING_FINGER_TIP", + "PINKY_MCP", + "PINKY_PIP", + "PINKY_DIP", + "PINKY_TIP" + ], + "RIGHT_HAND_LANDMARKS": [ + "WRIST", + "THUMB_CMC", + "THUMB_MCP", + "THUMB_IP", + "THUMB_TIP", + "INDEX_FINGER_MCP", + "INDEX_FINGER_PIP", + "INDEX_FINGER_DIP", + "INDEX_FINGER_TIP", + "MIDDLE_FINGER_MCP", + "MIDDLE_FINGER_PIP", + "MIDDLE_FINGER_DIP", + "MIDDLE_FINGER_TIP", + "RING_FINGER_MCP", + "RING_FINGER_PIP", + "RING_FINGER_DIP", + "RING_FINGER_TIP", + "PINKY_MCP", + "PINKY_PIP", + "PINKY_DIP", + "PINKY_TIP" + ], + "POSE_WORLD_LANDMARKS": [ + "NOSE", + "LEFT_EYE_INNER", + "LEFT_EYE", + "LEFT_EYE_OUTER", + "RIGHT_EYE_INNER", + "RIGHT_EYE", + "RIGHT_EYE_OUTER", + "LEFT_EAR", + "RIGHT_EAR", + "MOUTH_LEFT", + "MOUTH_RIGHT", + "LEFT_SHOULDER", + "RIGHT_SHOULDER", + "LEFT_ELBOW", + "RIGHT_ELBOW", + "LEFT_WRIST", + "RIGHT_WRIST", + "LEFT_PINKY", + "RIGHT_PINKY", + "LEFT_INDEX", + "RIGHT_INDEX", + "LEFT_THUMB", + "RIGHT_THUMB", + "LEFT_HIP", + "RIGHT_HIP", + "LEFT_KNEE", + "RIGHT_KNEE", + "LEFT_ANKLE", + "RIGHT_ANKLE", + "LEFT_HEEL", + "RIGHT_HEEL", + "LEFT_FOOT_INDEX", + "RIGHT_FOOT_INDEX" + ] +} \ No newline at end of file diff --git a/pose_evaluation/utils/test_pose_utils.py b/pose_evaluation/utils/test_pose_utils.py new file mode 100644 index 0000000..cd54cea --- /dev/null +++ b/pose_evaluation/utils/test_pose_utils.py @@ -0,0 +1,290 @@ +from typing import List, Dict + +from pathlib import Path + +import pytest +import numpy as np +import numpy.ma as ma # pylint: disable=consider-using-from-import + +from pose_format import Pose +from pose_format.utils.generic import detect_known_pose_format, pose_hide_legs +from pose_evaluation.utils.pose_utils import ( + load_pose_file, + pose_remove_world_landmarks, + pose_hide_low_conf, + get_face_and_hands_from_pose, + reduce_poses_to_intersection, + get_component_names_and_points_dict, + zero_pad_shorter_poses, + set_masked_to_origin_position, +) + + +def test_load_poses_mediapipe( + mediapipe_poses_test_data_paths: List[Path], + standard_mediapipe_components_dict: Dict[str, List[str]], +): + + poses = [load_pose_file(pose_path) for pose_path in mediapipe_poses_test_data_paths] + + assert len(poses) == 3 + + for pose in poses: + # do they all have headers? + assert pose.header is not None + + # check if the expected components are there. + for component in pose.header.components: + # should have specific expected components + assert component.name in standard_mediapipe_components_dict + + # should have specific expected points + assert sorted(component.points) == sorted( + standard_mediapipe_components_dict[component.name] + ) + + # checking the data: + # Frames, People, Points, Dims + assert pose.body.data.ndim == 4 + + # all frames have the standard shape? + assert all(frame.shape == (1, 576, 3) for frame in pose.body.data) + + +def test_remove_specific_landmarks_mediapipe( + mediapipe_poses_test_data: List[Pose], + standard_mediapipe_components_dict: Dict[str, List[str]], +): + for pose in mediapipe_poses_test_data: + component_count = len(pose.header.components) + assert component_count == len(standard_mediapipe_components_dict.keys()) + for component_name in standard_mediapipe_components_dict.keys(): + pose_with_component_removed =pose.remove_components([str(component_name)]) + assert component_name not in pose_with_component_removed.header.components + assert ( + len(pose_with_component_removed.header.components) + == component_count - 1 + ) + + +def test_pose_copy(mediapipe_poses_test_data: List[Pose]): + for pose in mediapipe_poses_test_data: + copy = pose.copy() + + assert copy != pose # Not the same object + assert pose.body != copy.body # also not the same + assert np.array_equal( + copy.body.data, pose.body.data + ) # the data should have the same values + + assert sorted([c.name for c in pose.header.components]) == sorted( + [c.name for c in copy.header.components] + ) # same components + assert ( + copy.header.total_points() == pose.header.total_points() + ) # same number of points + + +def test_pose_remove_legs(mediapipe_poses_test_data: List[Pose]): + points_that_should_be_removed = ["LEFT_KNEE", "LEFT_HEEL", "LEFT_FOOT", "LEFT_TOE", "LEFT_FOOT_INDEX", + "RIGHT_KNEE", "RIGHT_HEEL", "RIGHT_FOOT", "RIGHT_TOE", "RIGHT_FOOT_INDEX",] + for pose in mediapipe_poses_test_data: + c_names = [c.name for c in pose.header.components] + assert "POSE_LANDMARKS" in c_names + pose_landmarks_index = c_names.index("POSE_LANDMARKS") + assert "LEFT_KNEE" in pose.header.components[pose_landmarks_index].points + + pose_with_legs_removed = pose_hide_legs(pose, remove=True) + assert pose_with_legs_removed != pose + assert pose_with_legs_removed.header != pose.header + assert pose_with_legs_removed.header.components != pose.header.components + new_c_names = [c.name for c in pose_with_legs_removed.header.components] + assert "POSE_LANDMARKS" in new_c_names + + for component in pose_with_legs_removed.header.components: + point_names = [point.upper() for point in component.points] + for point_name in point_names: + for point_that_should_be_hidden in points_that_should_be_removed: + assert point_that_should_be_hidden not in point_name, f"{component.name}: {point_names}" + + +def test_pose_remove_legs_openpose(fake_openpose_poses): + points_that_should_be_removed = ["Hip", "Knee", "Ankle", "BigToe", "SmallToe", "Heel"] + for pose in fake_openpose_poses: + pose_with_legs_removed = pose_hide_legs(pose, remove=True) + + for component in pose_with_legs_removed.header.components: + point_names = list(point for point in component.points) + for point_name in point_names: + for point_that_should_be_hidden in points_that_should_be_removed: + assert point_that_should_be_hidden not in point_name, f"{component.name}: {point_names}" + + + +def test_reduce_pose_components_to_intersection( + mediapipe_poses_test_data: List[Pose], + standard_mediapipe_components_dict: Dict[str, List[str]], +): + + test_poses_with_one_reduced = [pose.copy() for pose in mediapipe_poses_test_data] + + pose_with_only_face_and_hands_and_no_wrist = get_face_and_hands_from_pose( + test_poses_with_one_reduced.pop() + ) + + c_names, p_dict = get_component_names_and_points_dict( + pose_with_only_face_and_hands_and_no_wrist + ) + + new_p_dict = {} + for c_name, p_list in p_dict.items(): + new_p_dict[c_name] = [ + point_name for point_name in p_list if "WRIST" not in point_name + ] + + pose_with_only_face_and_hands_and_no_wrist = ( + pose_with_only_face_and_hands_and_no_wrist.get_components(c_names, new_p_dict) + ) + + test_poses_with_one_reduced.append(pose_with_only_face_and_hands_and_no_wrist) + assert len(mediapipe_poses_test_data) == len(test_poses_with_one_reduced) + + original_component_count = len( + standard_mediapipe_components_dict.keys() + ) # 5, at time of writing + + target_component_count = 3 # face, left hand, right hand + assert ( + len(pose_with_only_face_and_hands_and_no_wrist.header.components) + == target_component_count + ) + + target_point_count = ( + pose_with_only_face_and_hands_and_no_wrist.header.total_points() + ) + + reduced_poses = reduce_poses_to_intersection( + test_poses_with_one_reduced + ) + for reduced_pose in reduced_poses: + assert len(reduced_pose.header.components) == target_component_count + assert reduced_pose.header.total_points() == target_point_count + + # check if the originals are unaffected + assert all( + len(pose.header.components) == original_component_count + for pose in mediapipe_poses_test_data + ) + + +def test_remove_world_landmarks(mediapipe_poses_test_data: List[Pose]): + for pose in mediapipe_poses_test_data: + component_names = [c.name for c in pose.header.components] + starting_component_count = len(pose.header.components) + assert "POSE_WORLD_LANDMARKS" in component_names + + pose = pose_remove_world_landmarks(pose) + component_names = [c.name for c in pose.header.components] + assert "POSE_WORLD_LANDMARKS" not in component_names + ending_component_count = len(pose.header.components) + + assert ending_component_count == starting_component_count - 1 + + +def test_remove_one_point_and_one_component(mediapipe_poses_test_data: List[Pose]): + component_to_drop = "POSE_WORLD_LANDMARKS" + point_to_drop = "LEFT_KNEE" + for pose in mediapipe_poses_test_data: + original_component_names, original_points_dict = ( + get_component_names_and_points_dict(pose) + ) + + assert component_to_drop in original_component_names + assert point_to_drop in original_points_dict["POSE_LANDMARKS"] + reduced_pose = pose.remove_components(component_to_drop, {"POSE_LANDMARKS": [point_to_drop]}) + new_component_names, new_points_dict = get_component_names_and_points_dict( + reduced_pose + ) + assert component_to_drop not in new_component_names + assert point_to_drop not in new_points_dict["POSE_LANDMARKS"] + + +def test_detect_format( + fake_openpose_poses, fake_openpose_135_poses, mediapipe_poses_test_data +): + for pose in fake_openpose_poses: + assert detect_known_pose_format(pose) == "openpose" + + for pose in fake_openpose_135_poses: + assert detect_known_pose_format(pose) == "openpose_135" + + for pose in mediapipe_poses_test_data: + assert detect_known_pose_format(pose) == "holistic" + + for pose in mediapipe_poses_test_data: + unsupported_component_name = "UNSUPPORTED" + pose.header.components[0].name = unsupported_component_name + pose = pose.get_components(["UNSUPPORTED"]) + assert len(pose.header.components) == 1 + + with pytest.raises( + ValueError, match="Could not detect pose format, unknown pose header schema with component names" + ): + detect_known_pose_format(pose) + + +def test_set_masked_to_origin_pos(mediapipe_poses_test_data: List[Pose]): + # Create a copy of the original poses for comparison + originals = [pose.copy() for pose in mediapipe_poses_test_data] + + # Apply the transformation + poses = [set_masked_to_origin_position(pose) for pose in mediapipe_poses_test_data] + + for original, transformed in zip(originals, poses): + # 1. Ensure the transformed data is still a MaskedArray + assert isinstance(transformed.body.data, np.ma.MaskedArray) + + # # 2. Ensure the mask is now all False, meaning data _exists_ though its _value_ is now zero + # assert np.ma.all(~transformed.body.data.mask) + # assert original.body.data.mask.sum() == 0 + assert transformed.body.data.mask.sum() == 0 + + # 3. Check the shape matches the original + assert transformed.body.data.shape == original.body.data.shape + + # 4. Validate masked positions in the original are now zeros + assert ma.all(transformed.body.data.data[original.body.data.mask] == 0) + + # 5. Validate unmasked positions in the original remain unchanged + assert ma.all( + transformed.body.data.data[~original.body.data.mask] + == original.body.data.data[~original.body.data.mask] + ) + + +def test_hide_low_conf(mediapipe_poses_test_data: List[Pose]): + copies = [pose.copy() for pose in mediapipe_poses_test_data] + for pose, copy in zip(mediapipe_poses_test_data, copies): + pose_hide_low_conf(pose, 1.0) + + assert np.array_equal(pose.body.confidence, copy.body.confidence) is False + + +def test_zero_pad_shorter_poses(mediapipe_poses_test_data: List[Pose]): + copies = [pose.copy() for pose in mediapipe_poses_test_data] + + max_len = max(len(pose.body.data) for pose in mediapipe_poses_test_data) + padded_poses = zero_pad_shorter_poses(mediapipe_poses_test_data) + + for i, padded_pose in enumerate(padded_poses): + assert ( + mediapipe_poses_test_data[i] != padded_poses[i] + ) # shouldn't be the same object + old_length = len(copies[i].body.data) + new_length = len(padded_pose.body.data) + assert new_length == max_len + if old_length == new_length: + assert old_length == max_len + + # does the confidence match? + assert padded_pose.body.confidence.shape == padded_pose.body.data.shape[:-1]