From fe0661f403a3a1bff81db7ca6b13fdf38e2adaf7 Mon Sep 17 00:00:00 2001
From: Miltos
Date: Sun, 14 May 2023 22:20:04 +0100
Subject: [PATCH] Revert "tsne vis: change the model & embeddings"

This reverts commit edb4eb569981400d345e30c5a402765a2cfbd2e7.
---
 etc/compute_embeddings.py | 29 +++++++++--------------------
 1 file changed, 9 insertions(+), 20 deletions(-)

diff --git a/etc/compute_embeddings.py b/etc/compute_embeddings.py
index 950a8311..1e0c8da8 100644
--- a/etc/compute_embeddings.py
+++ b/etc/compute_embeddings.py
@@ -3,7 +3,6 @@
 import numpy as np
 
 import torch
-import torch.nn.functional as F
 import sklearn.manifold
 import transformers
 
@@ -14,20 +13,13 @@ def parse_arguments():
     parser.add_argument("json", default=False, help="the path the json containing all papers.")
     parser.add_argument("outpath", default=False, help="the target path of the visualizations papers.")
     parser.add_argument("--seed", default=0, help="The seed for TSNE.", type=int)
-    parser.add_argument("--model", default='sentence-transformers/all-MiniLM-L6-v2', help="Name of the HF model")
-
     return parser.parse_args()
 
-def mean_pooling(token_embeddings, attention_mask):
-    """ Mean Pooling, takes attention mask into account for correct averaging"""
-    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-
 if __name__ == "__main__":
     args = parse_arguments()
 
-    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model)
-    model = transformers.AutoModel.from_pretrained(args.model)
+    tokenizer = transformers.AutoTokenizer.from_pretrained("deepset/sentence_bert")
+    model = transformers.AutoModel.from_pretrained("deepset/sentence_bert")
     model.eval()
 
     with open(args.json) as f:
@@ -35,19 +27,16 @@ def mean_pooling(token_embeddings, attention_mask):
 
     print(f"Num papers: {len(data)}")
 
-    corpus = []
+    all_embeddings = []
     for paper_info in data:
-        corpus.append(tokenizer.sep_token.join([paper_info['title'], paper_info['abstract']]))
-
-    encoded_corpus = tokenizer(corpus, padding=True, truncation=True, return_tensors='pt')
-    with torch.no_grad():
-        hidden_states = model(**encoded_corpus).last_hidden_state
-
-    corpus_embeddings = mean_pooling(hidden_states, encoded_corpus['attention_mask'])
-    corpus_embeddings = F.normalize(corpus_embeddings, p=2, dim=1)
+        with torch.no_grad():
+            token_ids = torch.tensor([tokenizer.encode(paper_info["abstract"])][:512])
+            hidden_states, _ = model(token_ids)[-2:]
+            all_embeddings.append(hidden_states.mean(0).mean(0).numpy())
 
     np.random.seed(args.seed)
-    out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(corpus_embeddings)
+    all_embeddings = np.array(all_embeddings)
+    out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(all_embeddings)
 
     for i, paper_info in enumerate(data):
         paper_info['tsne_embedding'] = out[i].tolist()
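
A note on the restored loop, for whoever touches this file next: the expression
torch.tensor([tokenizer.encode(paper_info["abstract"])][:512]) applies [:512] to
the outer one-element list, not to the token ids, so long abstracts are never
actually truncated and can overflow BERT's 512-position limit. A minimal sketch
of the presumably intended truncation (embed_abstract is an illustrative helper,
not part of this patch; the model name is the same one the patch restores):

    import torch
    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained("deepset/sentence_bert")
    model = transformers.AutoModel.from_pretrained("deepset/sentence_bert")
    model.eval()

    def embed_abstract(abstract):
        """Mean-pool the last hidden layer over one (truncated) abstract."""
        with torch.no_grad():
            # Truncate the token ids themselves, not the list wrapping them.
            token_ids = torch.tensor([tokenizer.encode(abstract)[:512]])
            hidden_states = model(token_ids)[0]  # (1, seq_len, hidden_dim)
        return hidden_states.mean(1).squeeze(0).numpy()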