diff --git a/compute_pppl.py b/compute_pppl.py
new file mode 100644
index 0000000000000000000000000000000000000000..95c2ecb3941f0fc6ba1eff9b580b9df02e2e6d52
--- /dev/null
+++ b/compute_pppl.py
@@ -0,0 +1,76 @@
+import torch
+import math
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+
+def calculate_pseudo_perplexity(text, model_name):
+    # Load the model and tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForMaskedLM.from_pretrained(model_name)
+    model.eval()
+
+    # Tokenize the input sentence
+    tokens = tokenizer(text, return_tensors="pt", truncation=True)
+    input_ids = tokens["input_ids"]
+    attention_mask = tokens["attention_mask"]
+    word_ids = tokens.word_ids()
+    num_tokens = input_ids.size(1)
+
+    pppl_per_word = []
+    curr_word_id = None
+    word_log_prob = 0
+    token_count = 0
+    words = []
+
+    # Iteratively mask each token (skip the [CLS] and [SEP] special tokens)
+    for i in range(1, num_tokens - 1):
+        masked_input_ids = input_ids.clone()
+        masked_input_ids[0, i] = tokenizer.mask_token_id  # Mask the i-th token
+
+        # Predict the masked token
+        with torch.no_grad():
+            outputs = model(input_ids=masked_input_ids, attention_mask=attention_mask)
+            logits = outputs.logits
+
+        # Log probability of the original token at the masked position
+        target_token_id = input_ids[0, i]
+        log_probs = torch.log_softmax(logits[0, i], dim=-1)
+        trg_log_prob = log_probs[target_token_id].item()
+
+        # Word-level pseudo-perplexity, i.e. aggregate the scores of all tokens belonging to one word
+        if word_ids[i] != curr_word_id:
+            # Finalize the previous word's PPPL
+            if curr_word_id is not None and token_count > 0:
+                pppl_per_word.append(math.exp(-word_log_prob / token_count))
+
+            # Reset for the new word
+            curr_word_id = word_ids[i]
+            word_log_prob = 0
+            token_count = 0
+            words.append(tokenizer.convert_ids_to_tokens([input_ids[0, i]])[0])
+
+        # Accumulate the log probabilities of the tokens in the current word
+        word_log_prob += trg_log_prob
+        token_count += 1
+
+    # Finalize the last word in the sentence
+    if token_count > 0:
+        pppl_per_word.append(math.exp(-word_log_prob / token_count))
+
+    return words, pppl_per_word
+
+
+def main():
+    model_name = "bert-base-multilingual-cased"
+    #example_sent = "Man träffc hochmüthige Leute an, die davor angesehen seyn wollten, sie wären dazu gesetzt, daß sie den Erdboden richten sollten, und dieser Wahn."
+    #example_sent = "Hallo Welt!"
+    #example_sent = "Morgen gehen wir ins Kino."
+    example_sent = "Man träffc hochmüthige Leute an."
+    #example_sent = "Good morning!"
+    words, pseudo_perplexity = calculate_pseudo_perplexity(example_sent, model_name)
+    for word, pppl in zip(words, pseudo_perplexity):
+        print(word, ":", pppl)
+
+
+if __name__ == '__main__':
+    main()
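
The script above scores each word by masking its subword tokens one at a time, which costs one forward pass per token. If a single sentence-level pseudo-perplexity is also wanted, the same mask-and-score idea can be batched into one forward pass. The sketch below is a minimal, untested variant under the same torch/transformers dependencies; calculate_sentence_pppl is a hypothetical helper that is not part of the diff, and it assumes a single short sentence (no padding, so the default all-ones attention mask is fine).

import math
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

def calculate_sentence_pppl(text, model_name="bert-base-multilingual-cased"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()

    input_ids = tokenizer(text, return_tensors="pt", truncation=True)["input_ids"][0]
    positions = list(range(1, input_ids.size(0) - 1))  # skip [CLS] and [SEP]
    if not positions:
        return float("nan")

    # One masked copy of the sentence per content token, scored in a single forward pass
    batch = input_ids.repeat(len(positions), 1)
    for row, pos in enumerate(positions):
        batch[row, pos] = tokenizer.mask_token_id

    with torch.no_grad():
        logits = model(input_ids=batch).logits

    # Sum the log probabilities of the original tokens at their masked positions
    total_log_prob = 0.0
    for row, pos in enumerate(positions):
        log_probs = torch.log_softmax(logits[row, pos], dim=-1)
        total_log_prob += log_probs[input_ids[pos]].item()

    # Pseudo-perplexity: exp of the negative average pseudo-log-likelihood
    return math.exp(-total_log_prob / len(positions))

The batching trades memory for speed: the batch holds one copy of the sentence per token, which is fine for sentence-length inputs but would need chunking for long sequences.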