From c9ada1ad903ec9c1fdef71ce12559a7f5f2da15c Mon Sep 17 00:00:00 2001
From: Ubuntu
 <zbzscript@zbgpu01.ndj4anicnnlexeer5gqrjnt1hc.ax.internal.cloudapp.net>
Date: Mon, 20 Jan 2025 14:49:35 +0000
Subject: [PATCH] :sparkles: compute pseudo-perplexity per word

---
 compute_pppl.py | 94 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 94 insertions(+)
 create mode 100644 compute_pppl.py

diff --git a/compute_pppl.py b/compute_pppl.py
new file mode 100644
index 0000000..95c2ecb
--- /dev/null
+++ b/compute_pppl.py
@@ -0,0 +1,94 @@
+import torch
+import math
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+
+
+def calculate_pseudo_perplexity(text, model_name):
+    # Load the model and tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForMaskedLM.from_pretrained(model_name)
+    model.eval()
+
+    # Tokenize the input sentence
+    tokens = tokenizer(text, return_tensors="pt", truncation=True)
+    input_ids = tokens["input_ids"]
+    attention_mask = tokens["attention_mask"]
+    word_ids = tokens.word_ids()
+    
+    # Get tokenized words
+    words = tokenizer.convert_ids_to_tokens(input_ids[0])
+    num_tokens = input_ids.size(1)
+
+    print(words)
+    print(num_tokens)
+    print(tokens)
+    print(word_ids)
+
+
+    pppl_per_word = []
+    curr_word_id = None
+    word_log_prob = 0
+    token_count = 0
+    words = []
+
+    # Iteratively mask each token
+    for i in range(1, num_tokens - 1):  # Skip [CLS] and [SEP] tokens
+        masked_input_ids = input_ids.clone()
+        masked_input_ids[0, i] = tokenizer.mask_token_id  # Mask the i-th token
+        print(i, masked_input_ids)
+        print(tokenizer.convert_ids_to_tokens(masked_input_ids[0]))
+
+        # Predict masked token
+        with torch.no_grad():
+            outputs = model(input_ids=masked_input_ids, attention_mask=attention_mask)
+            logits = outputs.logits
+
+            print(logits)
+        
+        # Compute log probability of the target token
+        target_token_id = input_ids[0, i]
+        log_probs = torch.log_softmax(logits[0, i], dim=-1)
+        print("log prob mask:", log_probs[target_token_id].item())
+        trg_log_prob = log_probs[target_token_id].item()
+
+        # Calculate word-level pseudo-perplexity i.e. aggregate the scores of all tokens corresponding to one word
+        if word_ids[i] != curr_word_id:
+            # Finalize the previous word's PPPL
+            if curr_word_id is not None and token_count > 0:
+                pppl_per_word.append(math.exp(-word_log_prob / token_count))
+
+            # Reset for the new word
+            curr_word_id = word_ids[i]
+            word_log_prob = 0
+            token_count = 0
+            words.append(tokenizer.convert_ids_to_tokens([input_ids[0, i]])[0])
+        
+        # Add log probabilities for tokens in the current word
+        word_log_prob += trg_log_prob
+        token_count += 1
+
+    # Finalize the last word in the batch
+    if token_count > 0:
+        pppl_per_word.append(math.exp(-word_log_prob / token_count))
+    
+    return words, pppl_per_word
+
+
+
+
+
+def main():
+    model_name = "bert-base-multilingual-cased"
+    #example_sent = "Man träffc hochmüthige Leute an, die davor angesehen seyn wollten, sie wären dazu gesetzt, daß sie den Erdboden richten sollten, und dieser Wahn."
+    #example_sent = "Hallo Welt!"
+    #example_sent= "Morgen gehen wir ins Kino."
+    example_sent = "Man träffc hochmüthige Leute an."
+    #example_sent = "Good morning!"
+    words, pseudo_perplexity = calculate_pseudo_perplexity(example_sent, model_name)
+    for word, pppl in zip(words, pseudo_perplexity):
+        print(word, ":", pppl)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
-- 
GitLab