From 3577811abaf172996b732db7f42055a897b5c94f Mon Sep 17 00:00:00 2001
From: Ubuntu <zbzscript@zbgpu01.ndj4anicnnlexeer5gqrjnt1hc.ax.internal.cloudapp.net>
Date: Tue, 21 Jan 2025 13:15:00 +0000
Subject: [PATCH] :sparkles: implement sliding window

---
 compute_pppl.py | 73 +++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 58 insertions(+), 15 deletions(-)

diff --git a/compute_pppl.py b/compute_pppl.py
index 95c2ecb..7498f06 100644
--- a/compute_pppl.py
+++ b/compute_pppl.py
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, AutoModelForMaskedLM
 
 
 
-def calculate_pseudo_perplexity(text, model_name):
+def calculate_pseudo_perplexity(text, model_name, window_size):
     # Load the model and tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForMaskedLM.from_pretrained(model_name)
@@ -19,11 +19,12 @@ def calculate_pseudo_perplexity(text, model_name):
     # Get tokenized words
     words = tokenizer.convert_ids_to_tokens(input_ids[0])
     num_tokens = input_ids.size(1)
+    trg_pos = window_size // 2
 
-    print(words)
-    print(num_tokens)
-    print(tokens)
-    print(word_ids)
+    #print(words)
+    #print(num_tokens)
+    #print(tokens)
+    #print(word_ids)
 
     pppl_per_word = []
 
@@ -34,21 +35,59 @@ def calculate_pseudo_perplexity(text, model_name):
 
     # Iteratively mask each token
     for i in range(1, num_tokens - 1):  # Skip [CLS] and [SEP] tokens
-        masked_input_ids = input_ids.clone()
-        masked_input_ids[0, i] = tokenizer.mask_token_id  # Mask the i-th token
-        print(i, masked_input_ids)
-        print(tokenizer.convert_ids_to_tokens(masked_input_ids[0]))
+
+        # Define the sliding window range
+        start_idx = max(1, i - trg_pos)
+        end_idx = min(num_tokens - 1, i + trg_pos + 1)
+
+        # Extract the window
+        window_ids = input_ids[0, start_idx:end_idx]
+        window_attention_mask = attention_mask[0, start_idx:end_idx]
+        #print("WID:", window_ids)
+        #print("WAM:", window_attention_mask)
+
+        # Pad the window if it's smaller than window_size
+        if end_idx - start_idx < window_size:
+            padding_left = max(0, trg_pos - i + 1)
+            #print("pos:", trg_pos, "i:", i, "num_tokens:", num_tokens)
+            #print(padding_left)
+            padding_right = max(0, trg_pos - ((num_tokens - 2) - i))
+            #print(padding_right)
+            window_ids = torch.cat(
+                [
+                    input_ids[0, :1].repeat(padding_left),  # [CLS] token for padding
+                    window_ids,
+                    input_ids[0, -1:].repeat(padding_right),  # [SEP] token for padding
+                ]
+            )
+            window_attention_mask = torch.cat(
+                [
+                    attention_mask[0, :1].repeat(padding_left),
+                    window_attention_mask,
+                    attention_mask[0, -1:].repeat(padding_right),
+                ]
+            )
+
+
+        #print("WID:", window_ids)
+        #print("WAM:", window_attention_mask)
+
+        # Mask the middle token in the window
+        masked_window_ids = window_ids.clone()
+        masked_window_ids[trg_pos] = tokenizer.mask_token_id
+        print(i, masked_window_ids)
+        print(tokenizer.convert_ids_to_tokens(masked_window_ids))
 
         # Predict masked token
         with torch.no_grad():
-            outputs = model(input_ids=masked_input_ids, attention_mask=attention_mask)
+            outputs = model(input_ids=masked_window_ids.unsqueeze(0), attention_mask=window_attention_mask.unsqueeze(0))
             logits = outputs.logits
             print(logits)
 
         # Compute log probability of the target token
-        target_token_id = input_ids[0, i]
-        log_probs = torch.log_softmax(logits[0, i], dim=-1)
+        target_token_id = window_ids[trg_pos]
+        log_probs = torch.log_softmax(logits[0, trg_pos], dim=-1)
         print("log prob mask:", log_probs[target_token_id].item())
 
         trg_log_prob = log_probs[target_token_id].item()
@@ -80,12 +119,16 @@ def calculate_pseudo_perplexity(text, model_name):
 
 def main():
     model_name = "bert-base-multilingual-cased"
+    window_size = 11
     #example_sent = "Man träffc hochmüthige Leute an, die davor angesehen seyn wollten, sie wären dazu gesetzt, daß sie den Erdboden richten sollten, und dieser Wahn."
     #example_sent = "Hallo Welt!"
-    #example_sent= "Morgen gehen wir ins Kino."
-    example_sent = "Man träffc hochmüthige Leute an."
+    #example_sent= "Mofgen gehen win ims Kino und alle bekomen eine Tute Popcorn, weil si heute brv die Hausaufgben gemact haben."
+    example_sent = "Morgen gehen wir ins Kino und alle bekommen eine Tüte Popcorn, weil sie heute brav die Hausaufgaben gemacht haben."
+    #example_sent = "Ein grosser Baum steht vor dem Haus."
+
+    #example_sent = "Man träffc hochmüthige Leute an , die davor angesehen seyn wollten ."
     #example_sent = "Good morning!"
-    words, pseudo_perplexity = calculate_pseudo_perplexity(example_sent, model_name)
+    words, pseudo_perplexity = calculate_pseudo_perplexity(example_sent, model_name, window_size)
 
     for word, pppl in zip(words, pseudo_perplexity):
         print(word, ":", pppl)
-- 
GitLab