Commit 3577811a authored by Ubuntu

:sparkles: implement sliding window

parent c9ada1ad
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, AutoModelForMaskedLM
-def calculate_pseudo_perplexity(text, model_name):
+def calculate_pseudo_perplexity(text, model_name, window_size):
     # Load the model and tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForMaskedLM.from_pretrained(model_name)
@@ -19,11 +19,12 @@ def calculate_pseudo_perplexity(text, model_name):
     # Get tokenized words
     words = tokenizer.convert_ids_to_tokens(input_ids[0])
     num_tokens = input_ids.size(1)
+    trg_pos = window_size // 2
-    print(words)
-    print(num_tokens)
-    print(tokens)
-    print(word_ids)
+    #print(words)
+    #print(num_tokens)
+    #print(tokens)
+    #print(word_ids)
     pppl_per_word = []
@@ -34,21 +35,59 @@ def calculate_pseudo_perplexity(text, model_name):
     # Iteratively mask each token
     for i in range(1, num_tokens - 1):  # Skip [CLS] and [SEP] tokens
-        masked_input_ids = input_ids.clone()
-        masked_input_ids[0, i] = tokenizer.mask_token_id  # Mask the i-th token
-        print(i, masked_input_ids)
-        print(tokenizer.convert_ids_to_tokens(masked_input_ids[0]))
+        # Define the sliding window range
+        start_idx = max(1, i - trg_pos)
+        end_idx = min(num_tokens - 1, i + trg_pos + 1)
+        # Extract the window
+        window_ids = input_ids[0, start_idx:end_idx]
+        window_attention_mask = attention_mask[0, start_idx:end_idx]
+        #print("WID:", window_ids)
+        #print("WAM:", window_attention_mask)
+        # Pad the window if it's smaller than window_size
+        if end_idx - start_idx < window_size:
+            padding_left = max(0, trg_pos - i + 1)
+            #print("pos:", trg_pos, "i:", i, "num_tokens:", num_tokens)
+            #print(padding_left)
+            padding_right = max(0, trg_pos - ((num_tokens - 2) - i))
+            #print(padding_right)
+            window_ids = torch.cat(
+                [
+                    input_ids[0, :1].repeat(padding_left),  # [CLS] token for padding
+                    window_ids,
+                    input_ids[0, -1:].repeat(padding_right),  # [SEP] token for padding
+                ]
+            )
+            window_attention_mask = torch.cat(
+                [
+                    attention_mask[0, :1].repeat(padding_left),
+                    window_attention_mask,
+                    attention_mask[0, -1:].repeat(padding_right),
+                ]
+            )
+        #print("WID:", window_ids)
+        #print("WAM:", window_attention_mask)
+        # Mask the middle token in the window
+        masked_window_ids = window_ids.clone()
+        masked_window_ids[trg_pos] = tokenizer.mask_token_id
+        print(i, masked_window_ids)
+        print(tokenizer.convert_ids_to_tokens(masked_window_ids))
         # Predict masked token
         with torch.no_grad():
-            outputs = model(input_ids=masked_input_ids, attention_mask=attention_mask)
+            outputs = model(input_ids=masked_window_ids.unsqueeze(0), attention_mask=window_attention_mask.unsqueeze(0))
             logits = outputs.logits
             print(logits)
         # Compute log probability of the target token
-        target_token_id = input_ids[0, i]
-        log_probs = torch.log_softmax(logits[0, i], dim=-1)
+        target_token_id = window_ids[trg_pos]
+        log_probs = torch.log_softmax(logits[0, trg_pos], dim=-1)
         print("log prob mask:", log_probs[target_token_id].item())
         trg_log_prob = log_probs[target_token_id].item()
@@ -80,12 +119,16 @@ def calculate_pseudo_perplexity(text, model_name):
 def main():
     model_name = "bert-base-multilingual-cased"
+    window_size = 11
     #example_sent = "Man träffc hochmüthige Leute an, die davor angesehen seyn wollten, sie wären dazu gesetzt, daß sie den Erdboden richten sollten, und dieser Wahn."
     #example_sent = "Hallo Welt!"
     #example_sent = "Morgen gehen wir ins Kino."
-    example_sent = "Man träffc hochmüthige Leute an."
+    #example_sent = "Mofgen gehen win ims Kino und alle bekomen eine Tute Popcorn, weil si heute brv die Hausaufgben gemact haben."
+    example_sent = "Morgen gehen wir ins Kino und alle bekommen eine Tüte Popcorn, weil sie heute brav die Hausaufgaben gemacht haben."
     #example_sent = "Ein grosser Baum steht vor dem Haus."
     #example_sent = "Man träffc hochmüthige Leute an , die davor angesehen seyn wollten ."
     #example_sent = "Good morning!"
-    words, pseudo_perplexity = calculate_pseudo_perplexity(example_sent, model_name)
+    words, pseudo_perplexity = calculate_pseudo_perplexity(example_sent, model_name, window_size)
     for word, pppl in zip(words, pseudo_perplexity):
        print(word, ":", pppl)
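
For reference, below is a minimal standalone sketch of the window arithmetic this commit introduces; the token list, the window size of 5, and the helper name build_window are illustrative assumptions, not part of the commit. For each target position i, the window covers max(1, i - trg_pos) to min(num_tokens - 1, i + trg_pos + 1), and when the window is truncated at either edge it is padded with the first ([CLS]) and last ([SEP]) ids so that the masked target always sits at the centre index trg_pos = window_size // 2.

# Standalone sketch of the sliding-window index/padding logic from the diff.
# Token "ids" and window_size below are illustrative assumptions, not real data.

def build_window(seq, i, window_size):
    """Return the fixed-size window around position i, padded with the
    first ([CLS]) and last ([SEP]) elements as in the committed code."""
    num_tokens = len(seq)
    trg_pos = window_size // 2
    start_idx = max(1, i - trg_pos)
    end_idx = min(num_tokens - 1, i + trg_pos + 1)
    window = seq[start_idx:end_idx]
    if end_idx - start_idx < window_size:
        padding_left = max(0, trg_pos - i + 1)
        padding_right = max(0, trg_pos - ((num_tokens - 2) - i))
        window = [seq[0]] * padding_left + window + [seq[-1]] * padding_right
    return window, trg_pos

if __name__ == "__main__":
    # 10 "tokens": [CLS], t1..t8, [SEP]
    seq = ["[CLS]"] + [f"t{k}" for k in range(1, 9)] + ["[SEP]"]
    for i in range(1, len(seq) - 1):          # skip [CLS] and [SEP], as in the diff
        window, trg_pos = build_window(seq, i, window_size=5)
        assert window[trg_pos] == seq[i]      # target token sits at the centre
        print(i, window)

Running the sketch prints each padded window, and the assertion confirms the target token stays at the centre index, which is the invariant the masking step masked_window_ids[trg_pos] = tokenizer.mask_token_id relies on.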