Commit c9ada1ad authored by Ubuntu
:sparkles: compute pseudo-perplexity per word
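The script masks each sub-token of the input in turn, scores the original token with the masked LM, and aggregates the resulting log-probabilities per word via the tokenizer's word_ids() mapping. For a word w split into sub-tokens t_1, ..., t_n, the per-word pseudo-perplexity computed below is

\mathrm{PPPL}(w) = \exp\!\left(-\frac{1}{n}\sum_{i=1}^{n}\log P_{\mathrm{MLM}}\bigl(t_i \mid \text{sentence with } t_i \text{ masked}\bigr)\right)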

parent cdd78bab
import torch
import math
from transformers import AutoTokenizer, AutoModelForMaskedLM

def calculate_pseudo_perplexity(text, model_name):
    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    model.eval()

    # Tokenize the input sentence
    tokens = tokenizer(text, return_tensors="pt", truncation=True)
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    word_ids = tokens.word_ids()

    # Get tokenized words
    words = tokenizer.convert_ids_to_tokens(input_ids[0])
    num_tokens = input_ids.size(1)
    print(words)
    print(num_tokens)
    print(tokens)
    print(word_ids)

    pppl_per_word = []
    curr_word_id = None
    word_log_prob = 0
    token_count = 0
    words = []

    # Iteratively mask each token
    for i in range(1, num_tokens - 1):  # Skip [CLS] and [SEP] tokens
        masked_input_ids = input_ids.clone()
        masked_input_ids[0, i] = tokenizer.mask_token_id  # Mask the i-th token
        print(i, masked_input_ids)
        print(tokenizer.convert_ids_to_tokens(masked_input_ids[0]))

        # Predict the masked token
        with torch.no_grad():
            outputs = model(input_ids=masked_input_ids, attention_mask=attention_mask)
            logits = outputs.logits
        print(logits)

        # Compute the log probability of the target token
        target_token_id = input_ids[0, i]
        log_probs = torch.log_softmax(logits[0, i], dim=-1)
        print("log prob mask:", log_probs[target_token_id].item())
        trg_log_prob = log_probs[target_token_id].item()

        # Calculate word-level pseudo-perplexity, i.e. aggregate the scores of
        # all tokens corresponding to one word
        if word_ids[i] != curr_word_id:
            # Finalize the previous word's PPPL
            if curr_word_id is not None and token_count > 0:
                pppl_per_word.append(math.exp(-word_log_prob / token_count))
            # Reset for the new word
            curr_word_id = word_ids[i]
            word_log_prob = 0
            token_count = 0
            words.append(tokenizer.convert_ids_to_tokens([input_ids[0, i]])[0])

        # Add log probabilities for tokens in the current word
        word_log_prob += trg_log_prob
        token_count += 1

    # Finalize the last word in the sentence
    if token_count > 0:
        pppl_per_word.append(math.exp(-word_log_prob / token_count))

    return words, pppl_per_word

def main():
    model_name = "bert-base-multilingual-cased"
    #example_sent = "Man träffc hochmüthige Leute an, die davor angesehen seyn wollten, sie wären dazu gesetzt, daß sie den Erdboden richten sollten, und dieser Wahn."
    #example_sent = "Hallo Welt!"
    #example_sent = "Morgen gehen wir ins Kino."
    example_sent = "Man träffc hochmüthige Leute an."
    #example_sent = "Good morning!"
    words, pseudo_perplexity = calculate_pseudo_perplexity(example_sent, model_name)
    for word, pppl in zip(words, pseudo_perplexity):
        print(word, ":", pppl)


if __name__ == '__main__':
    main()
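Note: the words list returned by calculate_pseudo_perplexity keeps only the first sub-token of each word, so the printout shows word pieces rather than full surface words. If full words are preferred, they could be recovered from the same word_ids() mapping; a minimal sketch, assuming the fast-tokenizer word_to_chars() API (the script already requires a fast tokenizer for word_ids()) and not part of this commit:

def surface_words(text, tokens):
    # Rebuild each word from its character span in the raw input text.
    # `tokens` is the BatchEncoding returned by tokenizer(text, ...).
    words, seen = [], set()
    for wid in tokens.word_ids():
        if wid is None or wid in seen:  # skip special tokens and repeated sub-tokens
            continue
        seen.add(wid)
        span = tokens.word_to_chars(wid)  # CharSpan with .start and .end offsets
        words.append(text[span.start:span.end])
    return words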