From 7a4c164937f358bc7c5d5ed5e2591df38fece36f Mon Sep 17 00:00:00 2001
From: Ubuntu
 <zbzscript@zbgpu01.ndj4anicnnlexeer5gqrjnt1hc.ax.internal.cloudapp.net>
Date: Fri, 31 Jan 2025 11:48:16 +0000
Subject: [PATCH] :sparkles: compute pseudo-perplexity with the lmppl repo

---
 run_lmppl.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 run_lmppl.py

diff --git a/run_lmppl.py b/run_lmppl.py
new file mode 100644
index 0000000..0824af2
--- /dev/null
+++ b/run_lmppl.py
@@ -0,0 +1,81 @@
+import os
+import json
+import logging
+import argparse
+from pathlib import Path
+from lmppl.ppl_mlm import MaskedLM 
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+
+
def parse_args():
    """Build and parse the command-line interface for the pseudo-perplexity run.

    Returns:
        argparse.Namespace with model_name, input_path, output_path,
        batch_size and max_length attributes.
    """
    arg_parser = argparse.ArgumentParser(description="Run pseudo-perplexity calculation")
    arg_parser.add_argument("-m", "--model-name", type=str, default="bert-base-multilingual-uncased", help="Model name")
    arg_parser.add_argument("-i", "--input-path", type=str, default="data/sentences", help="Path to the input directory")
    arg_parser.add_argument("-o", "--output-path", type=str, default="data/pppl_per_sent", help="Path to the output directory")
    arg_parser.add_argument("-b", "--batch-size", type=int, default=32, help="Batch size")
    arg_parser.add_argument("-l", "--max-length", type=int, default=512, help="Max length of the input sequence to BERT")
    return arg_parser.parse_args()
+
+
+
def read_json(json_file):
    """Load a JSON file of sentence records and collect their OCR text.

    Args:
        json_file: an open file-like object containing a JSON array of
            objects, each with an "ocr" key.

    Returns:
        A tuple (records, ocr_texts) where records is the parsed JSON list
        and ocr_texts is the list of each record's "ocr" string, in order.
    """
    records = json.load(json_file)
    ocr_texts = [record["ocr"] for record in records]
    return records, ocr_texts
+
+
+
def check_text_length(text, model_name, max_length=512):
    """Return True if every sentence tokenizes to at most max_length tokens.

    The original hard-coded 512 even though the script exposes --max-length;
    the limit is now a parameter (default 512, so existing callers are
    unaffected).

    Args:
        text: iterable of sentence strings.
        model_name: Hugging Face model id used to load the tokenizer.
        max_length: maximum allowed token count per sentence.

    Returns:
        False (after logging a warning) as soon as one sentence exceeds the
        limit; True if all sentences fit.
    """
    # Load the tokenizer once for the whole batch of sentences.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    for sent in text:
        tokens = tokenizer(sent, return_tensors="pt", truncation=False)
        n_tokens = len(tokens["input_ids"][0])
        if n_tokens > max_length:
            # Lazy %-formatting so the message is only built when emitted.
            logging.warning("Sentence length is %s", n_tokens)
            return False
    return True
+
+
+
def main(args):
    """Score every JSON file under args.input_path with pseudo-perplexity.

    Walks the input directory, reads each JSON file of sentence records,
    verifies all sentences fit the model's length limit, scores them with a
    masked-LM pseudo-perplexity, and writes `<name>_pppl.json` files to
    args.output_path.

    Fixes over the original:
    - restores the `{filename}` placeholder that had been garbled to the
      literal text "(unknown)" in the open() path and both log messages;
    - builds the MaskedLM scorer lazily, once, instead of reloading the
      model for every file that passes the length check.
    """
    model_name = args.model_name
    Path(args.output_path).mkdir(parents=True, exist_ok=True)

    # Created on first use so the model is never loaded when no file passes
    # the length check (matches the original's lazy behavior).
    scorer = None

    # Load input data and check if the length of the input text is within the model's limit
    for root, _dirs, files in os.walk(args.input_path, topdown=True):
        for filename in files:
            with open(os.path.join(root, filename), "r", encoding="utf-8") as json_file:
                json_dict, text = read_json(json_file)

            if not check_text_length(text, model_name):
                logging.error("%s: length check failed", filename)
                continue

            logging.info("%s: length check passed", filename)

            # Calculate the pseudo-perplexity per sentence
            if scorer is None:
                scorer = MaskedLM(model_name, max_length=args.max_length)
            pppl = scorer.get_perplexity(text, batch_size=args.batch_size)

            # Add the pseudo-perplexity to each record of the json file
            for sent, score in zip(json_dict, pppl):
                sent["pppl_uncased"] = score

            # filename[:-5] assumed a ".json" suffix; Path.stem is robust to
            # any extension and identical for ".json" inputs.
            out_path = f"{args.output_path}/{Path(filename).stem}_pppl.json"
            with open(out_path, "w", encoding="utf-8") as out_file:
                json.dump(json_dict, out_file, ensure_ascii=False, indent=4)
+
+
+
if __name__ == "__main__":
    # Script entry point: parse CLI options and run the scoring pipeline.
    main(parse_args())
\ No newline at end of file
-- 
GitLab