run_calculate_embeddings_for_messages.py
from urllib.parse import urlparse
from pprint import pprint
import pandas as pd
from modules.query_resolver import query_resolver
import psycopg2
import psycopg2.extras
from pgvector.psycopg2 import register_vector
import numpy as np
import os
from decouple import AutoConfig
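# Read configuration (e.g. db_url) from a .env/settings file in the current
# working directory via python-decouple.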
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)
config = AutoConfig(search_path=current_directory)
url = urlparse(config("db_url"))
connection = psycopg2.connect(
    host=url.hostname,
    port=url.port,
    database=url.path[1:],
    user=url.username,
    password=url.password
)
cursor = connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
register_vector(connection)
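# register_vector() registers the pgvector type with psycopg2 so that numpy
# arrays can be passed directly as values for vector columns.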
# AI
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
# Mean Pooling - take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
# Load model from HuggingFace Hub
embedding_model = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(embedding_model)
model = AutoModel.from_pretrained(embedding_model)
# END AI
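# all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings, so the
# embedding column of messages_vectors_bert_t is presumably declared as
# vector(384).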
cursor.execute("select count(*) as count from messages_t")
query_results = cursor.fetchall()
result_df = pd.DataFrame(query_results)
num_of_messages = result_df.iloc[0]["count"]
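# Process the messages in pages of 100 rows so tokenization and the model
# forward pass work on small batches instead of the whole table at once.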
for i in range(0, num_of_messages, 100):
    query = f"""
        select
            *
        from
            messages_t
        limit 100
        offset {i}
    """
    cursor.execute(query)
    query_results = cursor.fetchall()
    result_df = pd.DataFrame(query_results)
    print(len(result_df))
    message_ids = []
    sentences = []
    for msg_content in result_df.itertuples():
        message_ids.append(msg_content.id)
        sentences.append(msg_content.msg_content)
        # print(msg_content.msg_content)
        # msg_content = row.iloc[0]["msg_content"]
        # print(msg_content)
    # START AI
    # Tokenize sentences (truncation=True trims messages longer than the
    # tokenizer's maximum sequence length)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    # END AI
    print("Sentence embeddings:")
    print(len(sentence_embeddings))
    print(len(sentence_embeddings[0]))
    print(len(sentence_embeddings[1]))
query = """
INSERT INTO messages_vectors_bert_t (
message_id ,
embedding_model ,
embedding
)
VALUES (
%s, %s, %s
);
"""
for i in range(len(message_ids)):
print("INSERTING")
cursor.execute(query, (message_ids[i], embedding_model, np.array( sentence_embeddings[i]) ))
connection.commit()
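# A rough sketch of how the stored embeddings could be queried afterwards
# (kept commented out; assumes the embedding column is a pgvector column and
# that query_embedding is a normalized 384-dim numpy array):
#
# cursor.execute(
#     """
#     SELECT message_id, embedding <=> %s AS cosine_distance
#     FROM messages_vectors_bert_t
#     ORDER BY embedding <=> %s
#     LIMIT 5
#     """,
#     (np.array(query_embedding), np.array(query_embedding)),
# )
# pprint(cursor.fetchall())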
# # Sentences we want sentence embeddings for
# sentences = ['This is an example sentence', 'Each sentence is converted']
# # Tokenize sentences
# encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# # Compute token embeddings
# with torch.no_grad():
# model_output = model(**encoded_input)
# # Perform pooling
# sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
# # Normalize embeddings
# sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
# print("Sentence embeddings:")
# print(sentence_embeddings)
# print(len(sentence_embeddings[0]))
# print(len(sentence_embeddings[1]))