
How do I get the sentence embedding from GritLM/emb_m7_nodes16_fast? #36

Open
phartman-keysight opened this issue May 24, 2024 · 4 comments


@phartman-keysight

Continuing from our conversation in #13; I just think it needed a new ticket at this point.

I am trying to finetune embeddings only, so I took your (@Muennighoff's) recommendation of using GritLM/emb_m7_nodes16_fast, but I don't see an embedding for the entire sentence/article, only the token embeddings. Am I misunderstanding something?

The standard GritLM model is both generative and an encoder, so the forward function is generative and encode produces the embedding. I therefore use model.encode(input_tokens, instruction), which returns a vector with shape (4096,) and works great. The model you recommended has no generative part, so I assumed forward is the embedding function and there is no encode function, right? The issue I'm hitting is that when I run model(input_tokens), I get back a tuple containing a 4096-dimensional embedding for each token, as opposed to a single embedding for the entire article. Should I be doing pooling on these, or is there some other function I should use to get the embedding?

Here is some example code:

import numpy as np
from gritlm import GritLM

def cosine_similarity(vecA, vecB):
    dotP = np.dot(vecA, vecB)
    magA = np.sqrt(np.sum(np.square(vecA)))
    magB = np.sqrt(np.sum(np.square(vecB)))
    return dotP / (magA * magB)


model = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode='embedding')


def get_similarity_score(x,y):
    x_rep = model.encode(x, instruction="<|embed|>\n")
    y_rep = model.encode(y, instruction="<|embed|>\n")
    return cosine_similarity(x_rep, y_rep)

article_1 = "The dog eats dog food."
article_2 = "A canine consumes kibble."
article_3 = "The cat eats cat food."

print(get_similarity_score(article_1, article_2))
print(get_similarity_score(article_1, article_3))
print(get_similarity_score(article_2, article_3))


import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('GritLM/emb_m7_nodes16_fast')
model = AutoModel.from_pretrained('GritLM/emb_m7_nodes16_fast', torch_dtype=torch.float32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

model.eval()

def get_similarity_score_automodel(model, tokenizer, article1, article2):
    article1_tokens = tokenizer(article1, padding=True, truncation=True, return_tensors='pt', max_length=512).to(model.device)
    article2_tokens = tokenizer(article2, padding=True, truncation=True, return_tensors='pt', max_length=512).to(model.device)
    
    with torch.no_grad():
        article1_embeddings = model(**article1_tokens)
        article2_embeddings = model(**article2_tokens)

    # This part currently fails since the variables are BaseModelOutputWithPast objects
    # I should use the last_hidden_state attribute, right? 
    # When I do use last_hidden_state, I get a tensor of shape (1, # of tokens, 4096) as opposed to (1, 4096) or (4096,)
    similarity = cosine_similarity(article1_embeddings.cpu().numpy(), article2_embeddings.cpu().numpy())
    return similarity.item()


print(get_similarity_score_automodel(model, tokenizer, article_1, article_2))
print(get_similarity_score_automodel(model, tokenizer, article_1, article_3))
print(get_similarity_score_automodel(model, tokenizer, article_2, article_3))
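For what it's worth, here is the mean pooling I would try if pooling is indeed the answer; the attention-mask weighting is my guess at what GritLM does internally, so please correct me if it's off:

def mean_pool(last_hidden_state, attention_mask):
    # Average only over real tokens; padded positions are zeroed out by the mask.
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # (batch, seq_len, 1)
    return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

article1_tokens = tokenizer(article_1, padding=True, truncation=True, return_tensors='pt', max_length=512).to(model.device)
with torch.no_grad():
    out = model(**article1_tokens)
article1_embedding = mean_pool(out.last_hidden_state, article1_tokens['attention_mask'])  # shape (1, 4096)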

Also, the embeddings won't be the same since they are different models, but they should result in similar similarity scores, right?

@Muennighoff
Contributor

You should be able to just load it as follows:

model = GritLM("GritLM/emb_m7_nodes16_fast", torch_dtype="auto", mode='embedding')

and use it in the same way as GritLM/GritLM-7B. Otherwise, finetuning GritLM/GritLM-7B is probably just as fine; I don't actually know which one would perform better.
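i.e. something like this (untested, but it should mirror the GritLM-7B usage above):

from gritlm import GritLM

model = GritLM("GritLM/emb_m7_nodes16_fast", torch_dtype="auto", mode='embedding')
# encode returns one pooled vector per input, just like with GritLM/GritLM-7B.
rep = model.encode("The dog eats dog food.", instruction="<|embed|>\n")
print(rep.shape)  # should be (4096,)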

@phartman-keysight
Author

I know standard GritLM generally uses mean pooling over the last hidden state for embeddings, and that it can use weighted mean, CLS, or last-token pooling instead. Mean pooling of the token embeddings is a common way to generate sentence embeddings, but I've also seen fully connected "pooler" layers: a single final dense layer that generates the embedding.
How did you decide on mean pooling rather than essentially having a "language head" and an "embedding head" that you apply to the last hidden state for either output (as opposed to a language head plus mean pooling)?
Do you think there would be performance improvements if someone applied that approach?

@Muennighoff
Contributor

So you would apply that head over the seq len?
I.e. (batch size, seq len, hidden dim) -> (batch size, 1, hidden dim)?

The problem is that seq len may change depending on the sample. You'd likely have to pad to a fixed number of tokens, and those padding tokens would then become part of the embedding since they're part of the matrix multiply, which may hurt performance.
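Concretely, a toy sketch of the shape problem (made-up dimensions, not actual GritLM code):

import torch
import torch.nn.functional as F

hidden, max_len = 8, 16                      # toy sizes; the real hidden dim would be 4096
# Hypothetical "embedding head": a fixed seq len is baked into its weight shape.
head = torch.nn.Linear(max_len * hidden, hidden)

h = torch.randn(1, 5, hidden)                # only 5 real tokens
h = F.pad(h, (0, 0, 0, max_len - 5))         # pad the seq dimension up to max_len
emb = head(h.reshape(1, -1))                 # the padding rows enter the matmul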

@phartman-keysight
Author

Yeah, I see what you're saying. The reason the language-model head works is that it only uses the last token's embedding to generate the logits for the next token. I hadn't realized that, so if I want something comparable I wouldn't build an "embedding head"; I would just use the last-token pooling approach, which you already support.
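Sketch of my understanding of last-token pooling (toy shapes, assumes right padding; not the exact GritLM code):

import torch

last_hidden_state = torch.randn(2, 5, 8)     # (batch, seq_len, hidden); real hidden dim is 4096
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
last_idx = attention_mask.sum(dim=1) - 1     # index of the last real token in each sequence
emb = last_hidden_state[torch.arange(last_hidden_state.size(0)), last_idx]  # (batch, hidden)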

I understand now, thanks for the response.
