Added more comments to the code

This commit is contained in:
ChaoticByte 2024-08-13 21:46:18 +02:00
parent cfec6bf64a
commit 464ede2444
No known key found for this signature in database

View file

@ -2,9 +2,12 @@
# Copyright (c) 2024 Julian Müller (ChaoticByte) # Copyright (c) 2024 Julian Müller (ChaoticByte)
# Disable FutureWarnings
import warnings import warnings
warnings.simplefilter(action='ignore', category=FutureWarning) warnings.simplefilter(action='ignore', category=FutureWarning)
# Imports
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from subprocess import check_call, DEVNULL from subprocess import check_call, DEVNULL
@ -15,15 +18,16 @@ from semantic_text_splitter import TextSplitter
from tokenizers import Tokenizer from tokenizers import Tokenizer
from transformers import pipeline from transformers import pipeline
# Some constant variables
NLP_MODEL = "facebook/bart-large-cnn" NLP_MODEL = "facebook/bart-large-cnn"
root_dir = Path(__file__).parent root_dir = Path(__file__).parent
whisper_cpp_binary = (root_dir / "vendor" / "whisper.cpp" / "main").__str__() whisper_cpp_binary = (root_dir / "vendor" / "whisper.cpp" / "main").__str__()
# Steps
# tasks
def convert_audio(media_file: str, output_file: str): def convert_audio(media_file: str, output_file: str):
'''Convert media to mono 16kHz pcm_s16le wav using ffmpeg'''
check_call([ check_call([
"ffmpeg", "ffmpeg",
"-hide_banner", "-hide_banner",
@ -35,6 +39,7 @@ def convert_audio(media_file: str, output_file: str):
output_file]) output_file])
def transcribe(model_file: str, audio_file: str, output_file: str): def transcribe(model_file: str, audio_file: str, output_file: str):
'''Transcribe audio file using whisper.cpp'''
check_call([ check_call([
whisper_cpp_binary, whisper_cpp_binary,
"-m", model_file, "-m", model_file,
@ -53,6 +58,7 @@ def cleanup_text(t: str) -> str:
return t return t
def split_text(t: str, max_tokens: int) -> List[str]: def split_text(t: str, max_tokens: int) -> List[str]:
'''Split text into semantic segments'''
tokenizer = Tokenizer.from_pretrained(NLP_MODEL) tokenizer = Tokenizer.from_pretrained(NLP_MODEL)
splitter = TextSplitter.from_huggingface_tokenizer( splitter = TextSplitter.from_huggingface_tokenizer(
tokenizer, (int(max_tokens*0.8), max_tokens)) tokenizer, (int(max_tokens*0.8), max_tokens))
@ -60,6 +66,7 @@ def split_text(t: str, max_tokens: int) -> List[str]:
return chunks return chunks
def summarize(chunks: List[str], summary_min: int, summary_max: int) -> str: def summarize(chunks: List[str], summary_min: int, summary_max: int) -> str:
'''Summarize all segments (chunks) using a language model'''
chunks_summarized = [] chunks_summarized = []
summ = pipeline("summarization", model=NLP_MODEL) summ = pipeline("summarization", model=NLP_MODEL)
for c in chunks: for c in chunks:
@ -67,9 +74,10 @@ def summarize(chunks: List[str], summary_min: int, summary_max: int) -> str:
summ(c, max_length=summary_max, min_length=summary_min, do_sample=False)[0]['summary_text'].strip()) summ(c, max_length=summary_max, min_length=summary_min, do_sample=False)[0]['summary_text'].strip())
return "\n".join(chunks_summarized) return "\n".join(chunks_summarized)
# # Main
if __name__ == "__main__": if __name__ == "__main__":
# parse commandline arguments
argp = ArgumentParser() argp = ArgumentParser()
argp.add_argument("--summin", metavar="n", type=int, default=10, help="The minimum lenght of a segment summary [10, min: 5]") argp.add_argument("--summin", metavar="n", type=int, default=10, help="The minimum lenght of a segment summary [10, min: 5]")
argp.add_argument("--summax", metavar="n", type=int, default=90, help="The maximum lenght of a segment summary [90, min: 5]") argp.add_argument("--summax", metavar="n", type=int, default=90, help="The maximum lenght of a segment summary [90, min: 5]")
@ -78,6 +86,7 @@ if __name__ == "__main__":
argp.add_argument("-i", required=True, metavar="filepath", type=Path, help="The path to the media file") argp.add_argument("-i", required=True, metavar="filepath", type=Path, help="The path to the media file")
argp.add_argument("-o", required=True, metavar="filepath", type=Path, help="Where to save the output text to") argp.add_argument("-o", required=True, metavar="filepath", type=Path, help="Where to save the output text to")
args = argp.parse_args() args = argp.parse_args()
# Clamp values
args.summin = max(5, args.summin) args.summin = max(5, args.summin)
args.summax = max(5, args.summax) args.summax = max(5, args.summax)
args.segmax = max(5, min(args.segmax, 500)) args.segmax = max(5, min(args.segmax, 500))
@ -100,5 +109,5 @@ if __name__ == "__main__":
summary = summarize(chunks, args.summin, args.summax) summary = summarize(chunks, args.summin, args.summax)
print(f"\n{summary}\n") print(f"\n{summary}\n")
print(f"* Saving summary to {args.o.__str__()}") print(f"* Saving summary to {args.o.__str__()}")
with args.o.open("w+") as f: with args.o.open("w+") as f: # overwrites
f.write(summary) f.write(summary)