Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
.idea/
.env
.env
__pycache__
*.egg-info
models
lightning_logs/
wandb
68 changes: 53 additions & 15 deletions sign_language_segmentation/bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,45 @@ def save_pose_segments(tiers: dict, tier_id: str, input_file_path: Path) -> None
with out_path.open("wb") as f:
cropped_pose.write(f)

def time_str_to_ms(time_str):
    """
    Convert a subtitle timestamp string to integer milliseconds.

    Accepts "HH:MM:SS.mmm" as well as the short form "MM:SS.mmm" in which the
    hours field is omitted. A comma is accepted as the decimal separator too.
    The part after the separator is read as a decimal fraction of a second:
    ".5" means 500 ms, and digits beyond the third are dropped.

    Raises ValueError if the string does not have two or three
    colon-separated fields or contains non-numeric parts.
    """
    clock, _, fraction = time_str.strip().replace(',', '.').partition('.')
    fields = clock.split(':')
    if len(fields) == 2:
        # Short form without hours.
        fields.insert(0, '0')
    if len(fields) != 3:
        raise ValueError(f"Unrecognized timestamp: {time_str!r}")
    if fraction and not fraction.isdecimal():
        raise ValueError(f"Unrecognized timestamp: {time_str!r}")

    h, m, s = (int(field) for field in fields)
    # Right-pad / truncate to exactly three digits so the fraction is always
    # interpreted as milliseconds rather than as a bare integer.
    ms = int((fraction + '000')[:3])
    return h * 3600000 + m * 60000 + s * 1000 + ms

def add_subtitle_tier(eaf, subtitle_path, tier_name):
    """
    Add the cues of a subtitle file to `eaf` as annotations on a new tier.

    Files with a ".vtt" extension are read with webvtt-py; every other
    extension is parsed as SRT. Nothing happens if `subtitle_path` does not
    exist. Cue times are passed to `eaf.add_annotation` as integer
    milliseconds.

    Raises ImportError if a ".vtt" file is given and webvtt-py is missing.
    """
    if not os.path.exists(subtitle_path):
        return

    ext = os.path.splitext(subtitle_path)[1].lower()

    if ext == ".vtt":
        try:
            import webvtt
        except ImportError as err:
            raise ImportError(
                "webvtt-py is required to parse VTT files. Install it via 'pip install webvtt-py'."
            ) from err
        # Parse the time strings directly instead of going through float
        # seconds, to keep millisecond precision.
        cues = [
            (time_str_to_ms(caption.start), time_str_to_ms(caption.end), caption.text)
            for caption in webvtt.read(subtitle_path)
        ]
    else:
        import srt
        from datetime import timedelta

        # Floor-dividing a timedelta by one millisecond yields an exact int;
        # int(total_seconds() * 1000) can lose a millisecond to float rounding.
        one_ms = timedelta(milliseconds=1)
        # Explicit encoding so a leading UTF-8 BOM does not break parsing.
        with open(subtitle_path, "r", encoding="utf-8-sig") as infile:
            cues = [
                (subtitle.start // one_ms, subtitle.end // one_ms, subtitle.content)
                for subtitle in srt.parse(infile)
            ]

    # Create the tier only after the file was read, so a failed import or
    # parse does not leave an empty tier behind.
    eaf.add_tier(tier_name)
    for start, end, text in cues:
        # HACK: avoid zero or negative length annotation. Both formats are in
        # milliseconds here, so the minimum length is 1 ms for either.
        if start >= end:
            end = start + 1
        eaf.add_annotation(tier_name, start, end, text)

def get_args():
parser = argparse.ArgumentParser()
Expand All @@ -87,13 +126,16 @@ def get_args():
)
parser.add_argument("--video", default=None, required=False, type=str, help="path to video file")
parser.add_argument("--subtitles", default=None, required=False, type=str, help="path to subtitle file")
parser.add_argument("--model", default=DEFAULT_MODEL, required=False, type=str, help="path to model file")
parser.add_argument("--subtitles-corrected", default=None, required=False, type=str, help="path to subtitle file")
parser.add_argument("--model", default="model_E1s-1.pth", required=False, type=str, help="path to model file")
parser.add_argument("--sign-b-threshold", default=60, type=int)
parser.add_argument("--sign-o-threshold", default=50, type=int)
parser.add_argument("--no-pose-link", action="store_true", help="whether to link the pose file")

return parser.parse_args()


def segment_pose(pose: Pose, model: str = DEFAULT_MODEL, verbose=True):
def segment_pose(pose: Pose, model: str = DEFAULT_MODEL, verbose=True, sign_b_threshold=60, sign_o_threshold=50):
if "E4" in model:
pose = process_pose(pose, optical_flow=True, hand_normalization=True)
else:
Expand All @@ -108,7 +150,7 @@ def segment_pose(pose: Pose, model: str = DEFAULT_MODEL, verbose=True):
print("Estimating segments ...")
probs = predict(model, pose)

sign_segments = probs_to_segments(probs["sign"], 60, 50)
sign_segments = probs_to_segments(probs["sign"], sign_b_threshold, sign_o_threshold)
sentence_segments = probs_to_segments(probs["sentence"], 90, 90)

if verbose:
Expand Down Expand Up @@ -143,7 +185,7 @@ def main():
with open(args.pose, "rb") as f:
pose = Pose.read(f.read())

eaf, tiers = segment_pose(pose, model=args.model)
eaf, tiers = segment_pose(pose, model=args.model, sign_b_threshold=args.sign_b_threshold, sign_o_threshold=args.sign_o_threshold)

if args.video is not None:
mimetype = None # pympi is not familiar with mp4 files
Expand All @@ -154,18 +196,14 @@ def main():
if not args.no_pose_link:
eaf.add_linked_file(args.pose, mimetype="application/pose")

if args.subtitles and os.path.exists(args.subtitles):
import srt
if args.save_segments:
print(f"Saving {args.save_segments} cropped .pose files")
save_pose_segments(tiers, tier_id=args.save_segments, input_file_path=args.pose)

eaf.add_tier("SUBTITLE")
# open with explicit encoding,
# as directed in https://github.com/cdown/srt/blob/master/srt_tools/utils.py#L155-L160
# see also https://github.com/cdown/srt/issues/67, https://github.com/cdown/srt/issues/36
with open(args.subtitles, "r", encoding="utf-8-sig") as infile:
for subtitle in srt.parse(infile):
start = subtitle.start.total_seconds()
end = subtitle.end.total_seconds()
eaf.add_annotation("SUBTITLE", int(start * 1000), int(end * 1000), subtitle.content)
if args.subtitles:
add_subtitle_tier(eaf, args.subtitles, "SUBTITLE")
if args.subtitles_corrected:
add_subtitle_tier(eaf, args.subtitles_corrected, "SUBTITLE_CORRECTED")

print("Saving .eaf to disk ...")
eaf.to_file(args.elan)
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
4 changes: 2 additions & 2 deletions sign_language_segmentation/src/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ def boolean_string(s):
parser.add_argument('--gpus', type=int, default=1, help='how many gpus')
parser.add_argument('--epochs', type=int, default=100, help='how many epochs')
parser.add_argument('--patience', type=int, default=20, help='how many epochs as the patience for early stopping')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--batch_size_devtest', type=int, default=20,
help='batch size for dev and test (by default run all in one batch)')
parser.add_argument('--learning_rate', type=float, default=1e-3, help='optimizer learning rate')
parser.add_argument('--lr_scheduler', type=str, default='none', help='optimizer learning rate scheduler')

# Data Arguments
parser.add_argument('--dataset', choices=['dgs_corpus', 'mediapi_skel'],
parser.add_argument('--dataset', choices=['dgs_corpus', 'mediapi_skel', 'bobsl_cslr', 'bslcp'],
default='dgs_corpus', help='which dataset to use?')
parser.add_argument('--data_dir', help='which dir to store the dataset?')
parser.add_argument('--data_dev', type=boolean_string, default=False,
Expand Down
Empty file.
Loading