From 065188579e54dbdfdf70e217233ac041c8421b3b Mon Sep 17 00:00:00 2001
From: Ubuntu
 <zbzscript@zbgpu01.ndj4anicnnlexeer5gqrjnt1hc.ax.internal.cloudapp.net>
Date: Fri, 31 Jan 2025 11:07:16 +0000
Subject: [PATCH] :zap: restructure code and add batch processing

---
 compute_pppl.py | 223 ++++++++++++++++++++++++++++++++++--------------
 1 file changed, 158 insertions(+), 65 deletions(-)

diff --git a/compute_pppl.py b/compute_pppl.py
index 7498f06..11b6efd 100644
--- a/compute_pppl.py
+++ b/compute_pppl.py
@@ -1,37 +1,58 @@
+import os
 import torch
 import math
+import json
+import logging
+import argparse
+from tqdm import tqdm
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 
 
+# check if GPU is available
+print("CUDA available:", torch.cuda.is_available())
+print("Device count:", torch.cuda.device_count())
+if torch.cuda.is_available():
+    print("Current device:", torch.cuda.current_device())
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
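+
+# Example invocation (using the defaults defined in parse_args below):
+#   python compute_pppl.py -m bert-base-multilingual-uncased -i data/sentences -o data/pppl_per_sent -w 11 -b 32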
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compute pseudo-perplexity for each word in a text.")
+    parser.add_argument("-m", "--model-name", type=str, default="bert-base-multilingual-uncased", help="Model name or path")
+    parser.add_argument("-i", "--input-path", type=str, default="data/sentences", help="Path to the input directory")
+    parser.add_argument("-o", "--output-path", type=str, default="data/pppl_per_sent", help="Path to the output directory")
+    parser.add_argument("-w", "--window-size", type=int, default=11, help="Size of the sliding window")
+    parser.add_argument("-b", "--batch-size", type=int, default=32, help="Batch size for processing")
+    return parser.parse_args()
 
-def calculate_pseudo_perplexity(text, model_name, window_size):
-    # Load the model and tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForMaskedLM.from_pretrained(model_name)
-    model.eval()
 
-    # Tokenize the input sentence
-    tokens = tokenizer(text, return_tensors="pt", truncation=True)
+def prepare_sliding_windows(text, tokenizer, window_size):
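+    """Build one masked fixed-size window per token of `text`.
+
+    Each window is `window_size` tokens wide; the token at the centre position
+    is replaced by the mask token, and windows near the sentence boundaries are
+    padded so the target always sits at the centre index.
+    """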
+    trg_pos = window_size // 2
+    windows = {"masked_input_ids": [], "attention_mask": [], "true_token_ids": [], "word_ids": [], "words": []}
+
+    tokens = tokenizer(text, return_tensors="pt", truncation=False)
     input_ids = tokens["input_ids"]
     attention_mask = tokens["attention_mask"]
-    word_ids = tokens.word_ids()
-    
-    # Get tokenized words
-    words = tokenizer.convert_ids_to_tokens(input_ids[0])
+
+    # Get tokenized readable words
+    converted_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
     num_tokens = input_ids.size(1)
-    trg_pos = window_size // 2
 
-    #print(words)
-    #print(num_tokens)
-    #print(tokens)
-    #print(word_ids)
+    # Ensure that the word IDs correspond to the original whitespace-separated words in the text
+    word_ids = [None]  # ID for the [CLS] token
+    current_word_id = 0
+    words = text.split()
 
+    # Assign custom word IDs (one ID per whitespace-separated word)
+    for word in words:
+        # Tokenize each word separately
+        word_tokens = tokenizer.tokenize(word)
+        # Assign the same word ID to all tokens of the current word
+        word_ids.extend([current_word_id] * len(word_tokens))
+        current_word_id += 1
+    word_ids.append(None)  # ID for the [SEP] token
 
-    pppl_per_word = []
-    curr_word_id = None
-    word_log_prob = 0
-    token_count = 0
-    words = []
 
     # Iteratively mask each token
     for i in range(1, num_tokens - 1):  # Skip [CLS] and [SEP] tokens
@@ -43,16 +64,11 @@ def calculate_pseudo_perplexity(text, model_name, window_size):
         # Extract the window
         window_ids = input_ids[0, start_idx:end_idx]
         window_attention_mask = attention_mask[0, start_idx:end_idx]
-        #print("WID:", window_ids)
-        #print("WAM:", window_attention_mask)
 
         # Pad the window if it's smaller than window_size
         if end_idx - start_idx < window_size:
             padding_left = max(0, trg_pos - i + 1)
-            #print("pos:", trg_pos, "i:", i, "num_tokens:", num_tokens)
-            #print(padding_left)
             padding_right = max(0, trg_pos - ((num_tokens - 2) - i))
-            #print(padding_right)
             window_ids = torch.cat(
                 [
                     input_ids[0, :1].repeat(padding_left),  # [CLS] token for padding
@@ -68,70 +84,147 @@ def calculate_pseudo_perplexity(text, model_name, window_size):
                 ]
             )
 
-
-        #print("WID:", window_ids)
-        #print("WAM:", window_attention_mask)
-
         # Mask the middle token in the window
         masked_window_ids = window_ids.clone()
         masked_window_ids[trg_pos] = tokenizer.mask_token_id
-        print(i, masked_window_ids)
-        print(tokenizer.convert_ids_to_tokens(masked_window_ids))
 
-        # Predict masked token
+        # Store the masked window, its attention mask and the original (unmasked) target token
+        windows["masked_input_ids"].append(masked_window_ids.tolist())
+        windows["attention_mask"].append(window_attention_mask.tolist())
+        windows["true_token_ids"].append(window_ids[trg_pos].item())
+    windows["word_ids"].extend(word_ids)
+    windows["words"].extend(converted_tokens)
+
+    return windows
+
+
+def calculate_pseudo_perplexity(text, model_name, window_size=11, batch_size=32):
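+    """Compute one pseudo-perplexity score per whitespace-separated word of `text`.
+
+    For a word made up of tokens t_1..t_N, the score is
+        PPPL = exp(-(1/N) * sum_i log P(t_i | masked sliding-window context)),
+    i.e. the exponentiated negative mean of the masked-token log-probabilities.
+    Returns the list of per-word scores in text order.
+    """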
+    # Load the model and tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForMaskedLM.from_pretrained(model_name)
+    model.eval()
+
+    trg_pos = window_size // 2
+    pppl_per_word = []
+
+    # Prepare sliding windows for the current text
+    windows = prepare_sliding_windows(text, tokenizer, window_size)
+
+    # Convert to PyTorch tensors
+    masked_input_ids = torch.tensor(windows["masked_input_ids"])
+    attention_mask = torch.tensor(windows["attention_mask"])
+    true_token_ids = windows["true_token_ids"]
+    word_ids = windows["word_ids"]
+
+    # Process in batches
+    all_token_log_probs = []
+    for start_idx in tqdm(range(0, len(masked_input_ids), batch_size), desc="Processing batches"):
+        batch_input_ids = masked_input_ids[start_idx : start_idx + batch_size]
+        batch_attention_mask = attention_mask[start_idx : start_idx + batch_size]
+        batch_true_token_ids = true_token_ids[start_idx : start_idx + batch_size]
+
+        # Predict masked tokens
         with torch.no_grad():
-            outputs = model(input_ids=masked_window_ids.unsqueeze(0), attention_mask=window_attention_mask.unsqueeze(0))
+            outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
             logits = outputs.logits
 
-            print(logits)
-        
-        # Compute log probability of the target token
-        target_token_id = window_ids[trg_pos]
-        log_probs = torch.log_softmax(logits[0, trg_pos], dim=-1)
-        print("log prob mask:", log_probs[target_token_id].item())
-        trg_log_prob = log_probs[target_token_id].item()
-
-        # Calculate word-level pseudo-perplexity i.e. aggregate the scores of all tokens corresponding to one word
-        if word_ids[i] != curr_word_id:
+        # Compute log probability of the target tokens
+        log_probs = torch.log_softmax(logits[:, trg_pos, :], dim=-1)
+        token_log_probs = log_probs[torch.arange(len(batch_true_token_ids)), batch_true_token_ids]
+        all_token_log_probs.extend(token_log_probs.tolist())
+
+    # Calculate word-level pseudo-perplexity, i.e. aggregate the scores of all tokens corresponding to one word
+    current_word_id = None
+    word_log_prob = 0
+    token_count = 0
+    words = []
+    target_word_ids = word_ids[1:-1]  # word IDs of the masked tokens (skip [CLS]/[SEP])
+
+    for i in range(len(target_word_ids)):
+        if target_word_ids[i] != current_word_id:
             # Finalize the previous word's PPPL
-            if curr_word_id is not None and token_count > 0:
+            if current_word_id is not None and token_count > 0:
                 pppl_per_word.append(math.exp(-word_log_prob / token_count))
 
             # Reset for the new word
-            curr_word_id = word_ids[i]
+            current_word_id = target_word_ids[i]
             word_log_prob = 0
             token_count = 0
-            words.append(tokenizer.convert_ids_to_tokens([input_ids[0, i]])[0])
-        
+            words.append(tokenizer.convert_ids_to_tokens(true_token_ids[i]))
+
         # Add log probabilities for tokens in the current word
-        word_log_prob += trg_log_prob
-        token_count += 1
+        if target_word_ids[i] is not None:  # Skip positions without a word ID (special tokens)
+            word_log_prob += all_token_log_probs[i]
+            token_count += 1
 
     # Finalize the last word in the batch
     if token_count > 0:
         pppl_per_word.append(math.exp(-word_log_prob / token_count))
+
+    if len(pppl_per_word) != len(words):
+        logging.error("The lengths of the list of words and the list of pseudo-perplexity scores are not equal.")
+        raise ValueError("The lengths of the list of words and the list of pseudo-perplexity scores are not equal.")
     
-    return words, pppl_per_word
+    else:
+        logging.info("Valid output: The lengths of the list of words and the list of pseudo-perplexity scores are equal.")
+
 
+    for word, pppl in zip(words, pppl_per_word):
+        print(word, ":", pppl)
 
+    return pppl_per_word
 
 
 
-def main():
-    model_name = "bert-base-multilingual-cased"
-    window_size =11
-    #example_sent = "Man träffc hochmüthige Leute an, die davor angesehen seyn wollten, sie wären dazu gesetzt, daß sie den Erdboden richten sollten, und dieser Wahn."
-    #example_sent = "Hallo Welt!"
-    #example_sent= "Mofgen gehen win ims Kino und alle bekomen eine Tute Popcorn, weil si heute brv die Hausaufgben gemact haben."
-    example_sent= "Morgen gehen wir ins Kino und alle bekommen eine Tüte Popcorn, weil sie heute brav die Hausaufgaben gemacht haben."
-    #example_sent = "Ein grosser Baum steht vor dem Haus."
+def read_json(json_file):
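+    """Parse a page-level JSON file and return (parsed dict, whitespace-joined text).
+
+    The structure assumed here (as accessed below) is a list of pages, each
+    with a "content" list of {"word": ...} entries.
+    """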
+    text = []
+    json_dict = json.load(json_file)
+    for page in json_dict:
+        for word in page["content"]:
+            text.append(word["word"])
+    text = " ".join(text)
+    return json_dict, text
 
-    #example_sent = "Man träffc hochmüthige Leute an , die davor angesehen seyn wollten ."
-    #example_sent = "Good morning!"
-    words, pseudo_perplexity = calculate_pseudo_perplexity(example_sent, model_name, window_size)
-    for word, pppl in zip(words, pseudo_perplexity):
-        print(word, ":", pppl)
+
+
+def main(args):
+    model_name = args.model_name
+    window_size = args.window_size
+    batch_size = args.batch_size
+
+    # Make sure the output directory exists before writing results
+    os.makedirs(args.output_path, exist_ok=True)
+
+    for root, dirs, files in os.walk(args.input_path, topdown=True):
+        for filename in files:
+
+            # Read the JSON file
+            with open(f"{root}/{filename}", "r", encoding="utf-8") as json_file:
+                json_dict, text = read_json(json_file)
+
+            # Compute pseudo-perplexity for each word in the text
+            pppl_per_word = calculate_pseudo_perplexity(text, model_name, window_size, batch_size)
+
+            # Check if output is valid
+            if len(pppl_per_word) == len(text.split()):
+                logging.info("Valid output: The lengths of the text and the list of pseudo-perplexity scores are equal.")
+            else:
+                logging.error(
+                    "Length mismatch in %s: %d words vs. %d pseudo-perplexity scores.",
+                    filename, len(text.split()), len(pppl_per_word),
+                )
+                raise ValueError("The lengths of the text and the list of pseudo-perplexity scores are not equal.")
+
+            # Add the pseudo-perplexity scores to the JSON dictionary and write it to the output file
+            score_index = 0
+            for page in json_dict:
+                for word in page["content"]:
+                    word["pppl"] = pppl_per_word[score_index]
+                    score_index += 1
+
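+            # Write the scored copy as "<input name>_pppl.json" in the output directory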
+            with open(f"{args.output_path}/{filename[:-5]}_pppl.json", "w", encoding="utf-8") as json_file:
+                json.dump(json_dict, json_file, ensure_ascii=False, indent=4)
 
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    args = parse_args()
+    main(args)
-- 
GitLab