From 7a4c164937f358bc7c5d5ed5e2591df38fece36f Mon Sep 17 00:00:00 2001
From: Ubuntu <zbzscript@zbgpu01.ndj4anicnnlexeer5gqrjnt1hc.ax.internal.cloudapp.net>
Date: Fri, 31 Jan 2025 11:48:16 +0000
Subject: [PATCH] :sparkles: compute pseudo-perplexity with the lmppl repo

---
 run_lmppl.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 run_lmppl.py

diff --git a/run_lmppl.py b/run_lmppl.py
new file mode 100644
index 0000000..0824af2
--- /dev/null
+++ b/run_lmppl.py
@@ -0,0 +1,81 @@
+import os
+import json
+import logging
+import argparse
+from pathlib import Path
+from lmppl.ppl_mlm import MaskedLM
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Run pseudo-perplexity calculation")
+    parser.add_argument("-m", "--model-name", type=str, default="bert-base-multilingual-uncased", help="Model name")
+    parser.add_argument("-i", "--input-path", type=str, default="data/sentences", help="Path to the input directory")
+    parser.add_argument("-o", "--output-path", type=str, default="data/pppl_per_sent", help="Path to the output directory")
+    parser.add_argument("-b", "--batch-size", type=int, default=32, help="Batch size")
+    parser.add_argument("-l", "--max-length", type=int, default=512, help="Max length of the input sequence to BERT")
+    return parser.parse_args()
+
+
+
+def read_json(json_file):
+    text = []
+    json_dict = json.load(json_file)
+    for sent in json_dict:
+        text.append(sent["ocr"])
+    return json_dict, text
+
+
+
+def check_text_length(text, model_name):
+    # Load the tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    # Check if the length of the input text is within the model's limit
+    for sent in text:
+        tokens = tokenizer(sent, return_tensors="pt", truncation=False)
+        if len(tokens["input_ids"][0]) > 512:
+            logging.warning(f"Sentence length is {len(tokens['input_ids'][0])}")
+            return False
+    return True
+
+
+
+def main(args):
+    model_name = args.model_name
+    Path(args.output_path).mkdir(parents=True, exist_ok=True)
+
+    # Load input data and check if the length of the input text is within the model's limit
+    for (root, dirs, files) in os.walk(args.input_path, topdown=True):
+        for filename in files:
+
+            with open(f"{root}/{filename}", "r", encoding="utf-8") as json_file:
+                json_dict, text = read_json(json_file)
+
+            length_check = check_text_length(text, model_name)
+            if not length_check:
+                logging.error(f"{filename}: length check failed")
+
+            else:
+                logging.info(f"{filename}: length check passed")
+
+                # Calculate the pseudo-perplexity per sentence
+                scorer = MaskedLM(model_name, max_length=args.max_length)
+                pppl = scorer.get_perplexity(text, batch_size=args.batch_size)
+
+                # Add the pseudo-perplexity to the json file
+                for i, sent in enumerate(json_dict):
+                    sent["pppl_uncased"] = pppl[i]
+
+                with open(f"{args.output_path}/{filename[:-5]}_pppl.json", "w", encoding="utf-8") as json_file:
+                    json.dump(json_dict, json_file, ensure_ascii=False, indent=4)
+
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
\ No newline at end of file
--
GitLab
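
For context, a minimal sketch of the lmppl MaskedLM scoring that run_lmppl.py builds on, mirroring the constructor and get_perplexity arguments used in the patch. The model name and sentences below are illustrative placeholders, and lmppl is assumed to be installed (e.g. pip install lmppl):

    # Minimal sketch, assuming the same lmppl MaskedLM interface the patch uses.
    # The sentences here are illustrative placeholders, not project data.
    from lmppl.ppl_mlm import MaskedLM

    scorer = MaskedLM("bert-base-multilingual-uncased", max_length=512)
    sentences = ["An example OCR sentence.", "Another, noisier 0CR sentence!"]
    # get_perplexity returns one pseudo-perplexity score per input sentence
    scores = scorer.get_perplexity(sentences, batch_size=32)
    for sent, pppl in zip(sentences, scores):
        print(f"{pppl:.2f}\t{sent}")

Lower scores mean the masked language model finds a sentence more plausible, which is why the patch stores one score per sentence under the "pppl_uncased" key.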