Skip to content
Snippets Groups Projects
Commit 7a4c1649 authored by Ubuntu's avatar Ubuntu
Browse files

compute pseudo-perplexity with the lmppl repo

parent a81a8391
No related branches found
No related tags found
No related merge requests found
import os
import json
import logging
import argparse
from pathlib import Path
from lmppl.ppl_mlm import MaskedLM
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Configure logging: INFO and above go to the default stderr handler.
logging.basicConfig(level=logging.INFO)
def parse_args():
    """Parse command-line options for the pseudo-perplexity run.

    Returns an argparse.Namespace with model_name, input_path,
    output_path, batch_size and max_length attributes.
    """
    arg_parser = argparse.ArgumentParser(description="Run pseudo-perplexity calculation")
    # (short flag, long flag, type, default, help) — one row per option.
    option_specs = [
        ("-m", "--model-name", str, "bert-base-multilingual-uncased", "Model name"),
        ("-i", "--input-path", str, "data/sentences", "Path to the input directory"),
        ("-o", "--output-path", str, "data/pppl_per_sent", "Path to the output directory"),
        ("-b", "--batch-size", int, 32, "Batch size"),
        ("-l", "--max-length", int, 512, "Max length of the input sequence to BERT"),
    ]
    for short_flag, long_flag, arg_type, default, help_text in option_specs:
        arg_parser.add_argument(short_flag, long_flag, type=arg_type,
                                default=default, help=help_text)
    return arg_parser.parse_args()
def read_json(json_file):
    """Load a JSON array of sentence records and extract their OCR text.

    Parameters
    ----------
    json_file : an open text-mode file object containing a JSON list of
        objects, each with an "ocr" key.

    Returns
    -------
    tuple of (parsed records, list of the "ocr" string from each record,
    in the same order).
    """
    records = json.load(json_file)
    ocr_texts = [record["ocr"] for record in records]
    return records, ocr_texts
def check_text_length(text, model_name, max_length=512):
    """Check that every sentence fits within the model's token limit.

    Parameters
    ----------
    text : list of str
        Sentences to be scored.
    model_name : str
        Hugging Face model name; its tokenizer defines the token count.
    max_length : int, optional
        Maximum allowed number of tokens per sentence (default 512).
        Previously hard-coded; exposed so callers can pass args.max_length.

    Returns
    -------
    bool
        True if all sentences fit; False at the first overlong sentence
        (a warning with its token count is logged).
    """
    # Load the tokenizer once for the whole batch of sentences.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    for sent in text:
        # truncation=False so we see the true, untruncated token count.
        tokens = tokenizer(sent, return_tensors="pt", truncation=False)
        num_tokens = len(tokens["input_ids"][0])
        if num_tokens > max_length:
            logging.warning(f"Sentence length is {num_tokens}")
            return False
    return True
def main(args):
    """Walk the input directory, compute per-sentence pseudo-perplexity for
    each JSON file, and write augmented records to the output directory.

    Parameters
    ----------
    args : argparse.Namespace
        Options from parse_args(): model_name, input_path, output_path,
        batch_size, max_length.
    """
    model_name = args.model_name
    Path(args.output_path).mkdir(parents=True, exist_ok=True)
    # Build the scorer once up front — constructing MaskedLM loads the model,
    # which is far too expensive to repeat for every input file.
    scorer = MaskedLM(model_name, max_length=args.max_length)
    # Load input data and check if the length of the input text is within the model's limit
    for root, dirs, files in os.walk(args.input_path, topdown=True):
        for filename in files:
            # Only process JSON inputs; the output name strips the ".json" suffix.
            if not filename.endswith(".json"):
                continue
            # Fix: the path must include the actual filename, not a literal placeholder.
            with open(os.path.join(root, filename), "r", encoding="utf-8") as json_file:
                json_dict, text = read_json(json_file)
            if not check_text_length(text, model_name):
                logging.error(f"{filename}: length check failed")
            else:
                logging.info(f"{filename}: length check passed")
            # Calculate the pseudo-perplexity per sentence
            pppl = scorer.get_perplexity(text, batch_size=args.batch_size)
            # Add the pseudo-perplexity to the json file
            for i, sent in enumerate(json_dict):
                sent["pppl_uncased"] = pppl[i]
            out_path = f"{args.output_path}/{filename[:-5]}_pppl.json"
            with open(out_path, "w", encoding="utf-8") as json_file:
                json.dump(json_dict, json_file, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    args = parse_args()
    main(args)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment