Skip to content

Commit 6f2e2aa

Browse files
committed
added max_line_length parameter for .srt files
1 parent 0b5dcfd commit 6f2e2aa

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

whisper/transcribe.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,9 @@ def cli():
256256
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
257257
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
258258
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
259-
259+
260+
parser.add_argument("--max_line_length", type=optional_int, default=42, help="max amount of characters for a line in the subtitle files")
261+
260262
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
261263
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
262264

@@ -299,7 +301,9 @@ def cli():
299301
threads = args.pop("threads")
300302
if threads > 0:
301303
torch.set_num_threads(threads)
302-
304+
305+
output_max_line_length = args.pop("max_line_length")
306+
303307
from . import load_model
304308
model = load_model(model_name, device=device, download_root=model_dir)
305309

@@ -318,7 +322,7 @@ def cli():
318322

319323
# save SRT
320324
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
321-
write_srt(result["segments"], file=srt)
325+
write_srt(result["segments"], file=srt, max_line_length=output_max_line_length)
322326

323327

324328
if __name__ == '__main__':

whisper/utils.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
6161
)
6262

6363

64-
def write_srt(transcript: Iterator[dict], file: TextIO):
64+
def write_srt(transcript: Iterator[dict], file: TextIO, max_line_length: int = 42):
6565
"""
6666
Write a transcript to a file in SRT format.
6767
@@ -76,13 +76,52 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
7676
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
7777
write_srt(result["segments"], file=srt)
7878
"""
79+
80+
comma_split_threshold = int(float(max_line_length) * 0.75)
81+
7982
for i, segment in enumerate(transcript, start=1):
83+
84+
# left the .replace() here to not change unnecessarily
85+
# but I don't think it's needed?
86+
segment_text = segment['text'].strip().replace('-->', '->')
87+
88+
if len(segment_text) > max_line_length:
89+
segment_text = split_text_into_multiline(segment_text, max_line_length, comma_split_threshold)
90+
8091
# write srt lines
8192
print(
8293
f"{i}\n"
8394
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
8495
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
85-
f"{segment['text'].strip().replace('-->', '->')}\n",
96+
f"{segment_text}\n",
8697
file=file,
8798
flush=True,
8899
)
100+
101+
102+
def split_text_into_multiline(segment_text: str, max_line_length: int, comma_split_threshold: int) -> str:
    """
    Wrap a subtitle segment onto several lines, breaking only between words.

    Words are appended greedily to the current line until adding the next word
    would make it longer than `max_line_length`. A line is also ended early when
    it finishes with a comma and is already longer than `comma_split_threshold`,
    so that breaks tend to fall on natural pauses.

    Parameters
    ----------
    segment_text: str
        The text of a single segment; words are separated by spaces.
    max_line_length: int
        Maximum number of characters per line. A single word longer than this is
        kept intact on its own line. If None, the text is returned unchanged.
    comma_split_threshold: int
        Minimum length a comma-terminated line must exceed before a new line is
        started after the comma.

    Returns
    -------
    The text with the chosen line breaks inserted as newline characters.
    """
    # No limit requested: nothing to wrap.
    if max_line_length is None:
        return segment_text

    # Drop the empty strings produced by runs of spaces; otherwise they are
    # treated as words and wrapped lines end up starting with stray spaces.
    words = [word for word in segment_text.split(' ') if word]
    if not words:
        return segment_text

    lines = [words[0]]

    for word in words[1:]:
        current_line = lines[-1]

        # start a new line if the last word ended with a comma,
        # and we're mostly through this line
        if current_line.endswith(',') and len(current_line) > comma_split_threshold:
            lines.append(word)
            continue

        line_with_word = f'{current_line} {word}'

        if len(line_with_word) > max_line_length:
            lines.append(word)
        else:
            lines[-1] = line_with_word

    return '\n'.join(lines)
127+

0 commit comments

Comments
 (0)