From 065188579e54dbdfdf70e217233ac041c8421b3b Mon Sep 17 00:00:00 2001
From: Ubuntu <zbzscript@zbgpu01.ndj4anicnnlexeer5gqrjnt1hc.ax.internal.cloudapp.net>
Date: Fri, 31 Jan 2025 11:07:16 +0000
Subject: [PATCH] :zap: restructure code and add batch processing

---
 compute_pppl.py | 223 ++++++++++++++++++++++++++++++++++--------------
 1 file changed, 158 insertions(+), 65 deletions(-)

diff --git a/compute_pppl.py b/compute_pppl.py
index 7498f06..11b6efd 100644
--- a/compute_pppl.py
+++ b/compute_pppl.py
@@ -1,37 +1,58 @@
+import os
 import torch
 import math
+import json
+import logging
+import argparse
+from tqdm import tqdm
 from transformers import AutoTokenizer, AutoModelForMaskedLM

+# check if GPU is available
+print("CUDA available:", torch.cuda.is_available())
+print("Device count:", torch.cuda.device_count())
+print("Current device:", torch.cuda.current_device())
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compute pseudo-perplexity for each word in a text.")
+    parser.add_argument("-m", "--model-name", type=str, default="bert-base-multilingual-uncased", help="Model name or path")
+    parser.add_argument("-i", "--input-path", type=str, default="data/sentences", help="Path to the input directory")
+    parser.add_argument("-o", "--output-path", type=str, default="data/pppl_per_sent", help="Path to the output directory")
+    parser.add_argument("-w", "--window-size", type=int, default=11, help="Size of the sliding window")
+    parser.add_argument("-b", "--batch-size", type=int, default=32, help="Batch size for processing")
+    return parser.parse_args()

-def calculate_pseudo_perplexity(text, model_name, window_size):
-    # Load the model and tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForMaskedLM.from_pretrained(model_name)
-    model.eval()
-    # Tokenize the input sentence
-    tokens = tokenizer(text, return_tensors="pt", truncation=True)
+def prepare_sliding_windows(text, tokenizer, window_size):
+    trg_pos = window_size // 2
+    windows = {"masked_input_ids": [], "attention_mask": [], "true_token_ids": [], "word_ids": [], "words": []}
+
+    tokens = tokenizer(text, return_tensors="pt", truncation=False)
     input_ids = tokens["input_ids"]
     attention_mask = tokens["attention_mask"]
-    word_ids = tokens.word_ids()
-
-    # Get tokenized words
-    words = tokenizer.convert_ids_to_tokens(input_ids[0])
+
+    # Get tokenized readable words
+    converted_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
     num_tokens = input_ids.size(1)
-    trg_pos = window_size // 2
-    #print(words)
-    #print(num_tokens)
-    #print(tokens)
-    #print(word_ids)

+    # Ensure that the word IDs correspond to the original words (white-space separated) in the text
+    word_ids = [None]  # ID for the [CLS] token
+    current_word_id = 0
+    words = text.split()
+    # Assign customized word IDs
+    for word in words:
+        # Tokenize each word separately
+        word_tokens = tokenizer.tokenize(word)
+        # Assign the same word ID to all tokens of the current word
+        word_ids.extend([current_word_id] * len(word_tokens))
+        current_word_id += 1
+    word_ids.append(None)  # ID for the [SEP] token

-    pppl_per_word = []
-    curr_word_id = None
-    word_log_prob = 0
-    token_count = 0
-    words = []
     # Iteratively mask each token
     for i in range(1, num_tokens - 1):  # Skip [CLS] and [SEP] tokens
@@ -43,16 +64,11 @@ def calculate_pseudo_perplexity(text, model_name, window_size):
         # Extract the window
         window_ids = input_ids[0, start_idx:end_idx]
         window_attention_mask = attention_mask[0, start_idx:end_idx]
-        #print("WID:", window_ids)
-        #print("WAM:", window_attention_mask)

         # Pad the window if it's smaller than window_size
         if end_idx - start_idx < window_size:
             padding_left = max(0, trg_pos - i + 1)
-            #print("pos:", trg_pos, "i:", i, "num_tokens:", num_tokens)
-            #print(padding_left)
             padding_right = max(0, trg_pos - ((num_tokens - 2) - i))
-            #print(padding_right)
             window_ids = torch.cat(
                 [
                     input_ids[0, :1].repeat(padding_left),  # [CLS] token for padding
@@ -68,70 +84,147 @@ def calculate_pseudo_perplexity(text, model_name, window_size):
                 ]
             )

-        #print("WID:", window_ids)
-        #print("WAM:", window_attention_mask)
-
         # Mask the middle token in the window
         masked_window_ids = window_ids.clone()
         masked_window_ids[trg_pos] = tokenizer.mask_token_id
-        print(i, masked_window_ids)
-        print(tokenizer.convert_ids_to_tokens(masked_window_ids))
-        # Predict masked token
+        # Store the window, original token, attention mask, word IDs and tokens
+        windows["masked_input_ids"].append(masked_window_ids.tolist())
+        windows["attention_mask"].append(window_attention_mask.tolist())
+        windows["true_token_ids"].append(window_ids[trg_pos].item())
+    windows["word_ids"].extend(word_ids)
+    windows["words"].extend(converted_tokens)
+
+    return windows
+
+
+def calculate_pseudo_perplexity(text, model_name, window_size=11, batch_size=32):
+    # Load the model and tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForMaskedLM.from_pretrained(model_name)
+    model.eval()
+
+    trg_pos = window_size // 2
+    pppl_per_word = []
+
+    # Prepare sliding windows for the current text
+    windows = prepare_sliding_windows(text, tokenizer, window_size)
+
+    # Convert to PyTorch tensors
+    masked_input_ids = torch.tensor(windows["masked_input_ids"])
+    attention_mask = torch.tensor(windows["attention_mask"])
+    true_token_ids = windows["true_token_ids"]
+    word_ids = windows["word_ids"]
+
+    # Process in batches
+    all_token_log_probs = []
+    for start_idx in tqdm(range(0, len(masked_input_ids), batch_size), desc="Processing batches"):
+        batch_input_ids = masked_input_ids[start_idx : start_idx + batch_size]
+        batch_attention_mask = attention_mask[start_idx : start_idx + batch_size]
+        batch_true_token_ids = true_token_ids[start_idx : start_idx + batch_size]
+
+        # Predict masked tokens
         with torch.no_grad():
-            outputs = model(input_ids=masked_window_ids.unsqueeze(0), attention_mask=window_attention_mask.unsqueeze(0))
+            outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
             logits = outputs.logits
-        print(logits)
-
-        # Compute log probability of the target token
-        target_token_id = window_ids[trg_pos]
-        log_probs = torch.log_softmax(logits[0, trg_pos], dim=-1)
-        print("log prob mask:", log_probs[target_token_id].item())
-        trg_log_prob = log_probs[target_token_id].item()
-
-        # Calculate word-level pseudo-perplexity i.e. aggregate the scores of all tokens corresponding to one word
-        if word_ids[i] != curr_word_id:
+        # Compute log probability of the target tokens
+        log_probs = torch.log_softmax(logits[:, trg_pos, :], dim=-1)
+        token_log_probs = (log_probs[range(len(batch_true_token_ids)), batch_true_token_ids])
+        all_token_log_probs.extend(token_log_probs.tolist())
+
+
+    # Calculate word-level pseudo-perplexity i.e. aggregate the scores of all tokens corresponding to one word
+    current_word_id = None
+    word_log_prob = 0
+    token_count = 0
+    words = []
+    target_words = word_ids[1:-1]
+
+    for i in range(len(target_words)):
+        if target_words[i] != current_word_id:
             # Finalize the previous word's PPPL
-            if curr_word_id is not None and token_count > 0:
+            if current_word_id is not None and token_count > 0:
                 pppl_per_word.append(math.exp(-word_log_prob / token_count))
             # Reset for the new word
-            curr_word_id = word_ids[i]
+            current_word_id = target_words[i]
             word_log_prob = 0
             token_count = 0
-            words.append(tokenizer.convert_ids_to_tokens([input_ids[0, i]])[0])
-
+            words.append(tokenizer.convert_ids_to_tokens(true_token_ids[i]))
+
         # Add log probabilities for tokens in the current word
-        word_log_prob += trg_log_prob
-        token_count += 1
+        if target_words[i] is not None:  # Ignore padding tokens
+            word_log_prob += all_token_log_probs[i]
+            token_count += 1

     # Finalize the last word in the batch
     if token_count > 0:
         pppl_per_word.append(math.exp(-word_log_prob / token_count))
+
+
+    if len(pppl_per_word) != len(words):
+        logging.error("The lengths of the list of words and the list of pseudo-perplexity scores are not equal.")
+        raise ValueError("The lengths of the list of words and the list of pseudo-perplexity scores are not equal.")

-    return words, pppl_per_word
+    else:
+        logging.info("Valid output: The lengths of the list of words and the list of pseudo-perplexity scores are equal.")
+
+    for word, pppl in zip(words, pppl_per_word):
+        print(word, ":", pppl)

+    return pppl_per_word


-def main():
-    model_name = "bert-base-multilingual-cased"
-    window_size =11
-    #example_sent = "Man träffc hochmüthige Leute an, die davor angesehen seyn wollten, sie wären dazu gesetzt, daß sie den Erdboden richten sollten, und dieser Wahn."
-    #example_sent = "Hallo Welt!"
-    #example_sent= "Mofgen gehen win ims Kino und alle bekomen eine Tute Popcorn, weil si heute brv die Hausaufgben gemact haben."
-    example_sent= "Morgen gehen wir ins Kino und alle bekommen eine Tüte Popcorn, weil sie heute brav die Hausaufgaben gemacht haben."
-    #example_sent = "Ein grosser Baum steht vor dem Haus."
+def read_json(json_file):
+    text = []
+    json_dict = json.load(json_file)
+    for page in json_dict:
+        for word in page["content"]:
+            text.append(word["word"])
+    text = " ".join(text)
+    return json_dict, text

-    #example_sent = "Man träffc hochmüthige Leute an , die davor angesehen seyn wollten ."
-    #example_sent = "Good morning!"
-    words, pseudo_perplexity = calculate_pseudo_perplexity(example_sent, model_name, window_size)
-    for word, pppl in zip(words, pseudo_perplexity):
-        print(word, ":", pppl)
+
+
+def main(args):
+    model_name = args.model_name
+    window_size = args.window_size
+    batch_size = args.batch_size
+
+
+    for (root,dirs,files) in os.walk(args.input_path, topdown=True):
+        for filename in files:
+
+            # Read the JSON file
+            with open(f"{root}/{filename}", "r", encoding="utf-8") as json_file:
+                json_dict, text = read_json(json_file)
+
+            # Compute pseudo-perplexity for each word in the text
+            pppl_per_word = calculate_pseudo_perplexity(text, model_name, window_size, batch_size)
+
+            # Check if output is valid
+            if len(pppl_per_word) == len(text.split()):
+                logging.info("Valid output: The lengths of the text and the list of pseudo-perplexity scores are equal.")
+            else:
+                print(len(pppl_per_word), len(text.split()))
+                logging.error("The lengths of the text and the list of pseudo-perplexity scores are not equal.")
+                raise ValueError("The lengths of the text and the list of pseudo-perplexity scores are not equal.")
+
+
+            # Add the pseudo-perplexity scores to the JSON dictionary and write it to the output file
+            score_index = 0
+            for page in json_dict:
+                for word in page["content"]:
+                    word["pppl"] = pppl_per_word[score_index]
+                    score_index += 1
+
+            with open(f"{args.output_path}/{filename[:-5]}_pppl.json", "w", encoding="utf-8") as json_file:
+                json.dump(json_dict, json_file, ensure_ascii=False, indent=4)


 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    args = parse_args()
+    main(args)
\ No newline at end of file
--
GitLab
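
For reference, the restructured module can also be exercised directly from a Python session instead of through main(). The following is a minimal sketch only, not part of the patch: it assumes compute_pppl.py is importable from the working directory, that a CUDA device is present (the module queries the GPU at import time), and it simply reuses the defaults defined in parse_args().

# Illustrative usage sketch -- not part of the patch above.
# Assumes compute_pppl.py is on the import path and a CUDA device is available.
from compute_pppl import calculate_pseudo_perplexity

# One of the example sentences removed from the old main() in this patch.
sentence = "Morgen gehen wir ins Kino und alle bekommen eine Tüte Popcorn, weil sie heute brav die Hausaufgaben gemacht haben."

# Defaults mirror parse_args(): multilingual BERT, window size 11, batch size 32.
scores = calculate_pseudo_perplexity(
    sentence,
    model_name="bert-base-multilingual-uncased",
    window_size=11,
    batch_size=32,
)

# calculate_pseudo_perplexity returns one pseudo-perplexity value per
# whitespace-separated word; lower values mean the masked language model
# found the word more predictable in its sliding-window context.
for word, pppl in zip(sentence.split(), scores):
    print(word, pppl)

Run as a script, the equivalent batch job is python compute_pppl.py -i data/sentences -o data/pppl_per_sent, using the directory defaults from parse_args() and writing one *_pppl.json file per input file.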