Skip to content

Commit

Permalink
Merge pull request #79 from duggalsu/benchmark_es
Browse files Browse the repository at this point in the history
Add ElasticSearch benchmarking
  • Loading branch information
duggalsu authored Feb 14, 2024
2 parents ad94ad7 + 1e6b470 commit 03915d3
Show file tree
Hide file tree
Showing 9 changed files with 342 additions and 36 deletions.
3 changes: 3 additions & 0 deletions src/api/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ credentials.json

# pyinstrument
*.json

# locust
*.csv
77 changes: 48 additions & 29 deletions src/api/core/operators/vid_vec_rep_resnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import sys
import traceback

def initialize(param):
print("Installing packages for vid_vec_rep_resnet")
Expand Down Expand Up @@ -110,8 +111,8 @@ def check_constraints(self):
check if video is too big/unsupported.
return fail=1, set appropriate error
"""
if self.fsize > 20:
return False, "file size larger than 20 MB not supported"
if self.fsize > 10:
return False, "file size larger than 10 MB not supported"
# TODO : based on data statistics, how long it takes to process a video decide thresholds based on w x h, frames
return True, None

Expand Down Expand Up @@ -163,42 +164,60 @@ def extract_frames(self, v):
continue
else:
if i % self.sampling_rate == 0:
images.append(Image.fromarray(image))
# images.append(Image.fromarray(image))
yield [Image.fromarray(image)]
# print("extracted frames")
return images
# print("len(images):", len(images))
# print("sys.getsizeof(images[0])", sys.getsizeof(images[0]))
# print("sys.getsizeof(images)", sys.getsizeof(images))
# return images

def extract_features(self, images, batch_size=1):
try:
dset = ImageListDataset(images)
dloader = data.DataLoader(dset, batch_size=batch_size, shuffle=False)
res = []
feature_layer = self.model._modules.get("avgpool")

def hook(m, i, o):
feature_data = o.data.reshape((512, batch_size))
embedding.copy_(feature_data)

self.model.eval()
for i, image in enumerate(dloader):
embedding = torch.zeros(512, batch_size)
h = feature_layer.register_forward_hook(hook)
self.model(image)
h.remove()
res.append(embedding.numpy())
res = np.hstack(res)
assert res.shape == (512, len(images))
return res

except Exception:
print(logging.traceback.format_exc())
res = []
image_count = 0
for img in images:
# print("image_count: ", image_count)
image_count += 1
try:
dset = ImageListDataset(img)
dloader = data.DataLoader(dset, batch_size=batch_size, shuffle=False)
feature_layer = self.model._modules.get("avgpool")

def hook(m, i, o):
feature_data = o.data.reshape((512, batch_size))
embedding.copy_(feature_data)

self.model.eval()
for i, image in enumerate(dloader):
embedding = torch.zeros(512, batch_size)
h = feature_layer.register_forward_hook(hook)
self.model(image)
h.remove()
res.append(embedding.numpy())
# print("len(res)", len(res))
# res = np.hstack(res)
# print("res.shape:", res.shape)
# print("sys.getsizeof(res)", sys.getsizeof(res))
# assert res.shape == (512, len(images))
# return res

except Exception:
print(traceback.format_exc())

print("len(res)", len(res))
res = np.hstack(res)
print("res.shape:", res.shape)
print("sys.getsizeof(res)", sys.getsizeof(res))
assert res.shape == (512, image_count)
return res

def find_keyframes(self, feature_matrix):
# print("finding keyframes")
Q, R, P = qr(feature_matrix, pivoting=True, overwrite_a=False)
# Q is the orthogonal matrix that is an approximation of the featue matrix
# P is a pivot matrix containing indices of the original (feature matrix) image vectors that have the largest vector norms
# We select the first n indices from P to get the n keyframes
print(P)
# print(P)
idx = P[: self.n_keyframes]
# print("found keyframes")
return idx
Expand Down
2 changes: 2 additions & 0 deletions src/api/core/store/es_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,15 @@ def find(self, index_name, vec):
if type(vec) == np.ndarray:
vec = vec.tolist()

calculation = ""
if index_name == self.indices["text"]:
calculation = "1 / (1 + l2norm(params.query_vector, 'text_vec'))"
elif index_name == self.indices["image"]:
calculation = "1 / (1 + l2norm(params.query_vector, 'image_vec'))"
elif index_name == self.indices["video"]:
calculation = "1 / (1 + l2norm(params.query_vector, 'vid_vec'))"

print("calculation:", calculation)
q = {
"size": 10, # maximum number of hits returned by the query
"query": {
Expand Down
5 changes: 3 additions & 2 deletions src/api/requirements.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
flask==2.3.2
flask_cors==3.0.9
flask_cors==3.0.10
Pillow==10.2.0
elasticsearch==8.11.1
wget==3.2
Expand All @@ -12,4 +12,5 @@ dacite==1.8.1
memray==1.11.0 # dev
pyinstrument==4.6.2
numpy==1.26.3
requests==2.31.0
requests==2.31.0
locust==2.23.1
53 changes: 48 additions & 5 deletions src/api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@
#
blinker==1.7.0
# via flask
brotli==1.1.0
# via geventhttpclient
certifi==2024.2.2
# via
# elastic-transport
# geventhttpclient
# requests
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via
# -r requirements.in
# flask
configargparse==1.7
# via locust
dacite==1.8.1
# via -r requirements.in
elastic-transport==8.12.0
Expand All @@ -26,8 +31,22 @@ flask==2.3.2
# via
# -r requirements.in
# flask-cors
flask-cors==3.0.9
# via -r requirements.in
# flask-login
# locust
flask-cors==3.0.10
# via
# -r requirements.in
# locust
flask-login==0.6.3
# via locust
gevent==23.9.1
# via
# geventhttpclient
# locust
geventhttpclient==2.0.11
# via locust
greenlet==3.0.3
# via gevent
idna==3.6
# via requests
iniconfig==2.0.0
Expand All @@ -40,6 +59,8 @@ jinja2==3.1.3
# memray
linkify-it-py==2.0.3
# via markdown-it-py
locust==2.23.1
# via -r requirements.in
markdown-it-py[linkify,plugins]==3.0.0
# via
# mdit-py-plugins
Expand All @@ -55,6 +76,8 @@ mdurl==0.1.2
# via markdown-it-py
memray==1.11.0
# via -r requirements.in
msgpack==1.0.7
# via locust
numpy==1.26.3
# via -r requirements.in
packaging==23.2
Expand All @@ -65,6 +88,8 @@ pillow==10.2.0
# via -r requirements.in
pluggy==1.4.0
# via pytest
psutil==5.9.8
# via locust
pygments==2.17.2
# via rich
pyinstrument==4.6.2
Expand All @@ -75,14 +100,22 @@ python-dotenv==1.0.0
# via -r requirements.in
pyyaml==6.0.1
# via -r requirements.in
pyzmq==25.1.2
# via locust
requests==2.31.0
# via -r requirements.in
# via
# -r requirements.in
# locust
rich==13.7.0
# via
# memray
# textual
roundrobin==0.0.4
# via locust
six==1.16.0
# via flask-cors
# via
# flask-cors
# geventhttpclient
textual==0.50.1
# via memray
typing-extensions==4.9.0
Expand All @@ -94,6 +127,16 @@ urllib3==2.2.0
# elastic-transport
# requests
werkzeug==3.0.1
# via flask
# via
# flask
# flask-login
# locust
wget==3.2
# via -r requirements.in
zope-event==5.0
# via gevent
zope-interface==6.1
# via gevent

# The following packages are considered to be unsafe in a requirements file:
# setuptools
100 changes: 100 additions & 0 deletions src/api/test_video_es_vec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import unittest
from unittest.case import skip
import pprint
from core.store.es_vec import ES
from core.config import StoreConfig, StoreParameters
from core.models.media import MediaType
from core.operators import vid_vec_rep_resnet
from datetime import datetime

pp = pprint.PrettyPrinter(indent=4)
'''
# Get indexing stats
curl -X GET "http://es:9200/_stats/indexing?pretty"
# Check how many documents have been indexed
curl -X GET "http://es:9200/_cat/indices?v"
# Delete all the documents in an index
curl -X POST "http://es:9200/video/_delete_by_query" -H 'Content-Type: application/json' -d'{"query":{"match_all":{}}}'
'''


class TestVideoES(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
param_dict = {
"host_name": "es",
"text_index_name": "text",
"image_index_name": "image",
"video_index_name": "video",
}
cls.param = StoreConfig(
label="test",
type="es",
parameters=StoreParameters(
host_name=param_dict["host_name"],
image_index_name=param_dict["image_index_name"],
text_index_name=param_dict["text_index_name"],
video_index_name=param_dict["video_index_name"],
)
)

@classmethod
def tearDownClass(cls) -> None:
print("TEARING DOWN CLASS")

@staticmethod
def generate_document(post_id: str, representation: any):
base_doc = {
"e_kosh_id": "",
"dataset": post_id,
"metadata": None,
"date_added": datetime.now().isoformat(),
}

def generator_doc():
for vector in representation:
base_doc["_index"] = "video"
base_doc["vid_vec"] = vector["vid_vec"]
base_doc["is_avg"] = vector["is_avg"]
base_doc["duration"] = vector["duration"]
base_doc["n_keyframes"] = vector["n_keyframes"]
yield base_doc

return generator_doc

# @skip
def test_1_store_video_vector(self):
es = ES(self.param)
es.connect()

# generate video embedding
vid_vec_rep_resnet.initialize(param=None)
file_name = "sample-cat-video.mp4"
video = {"path": r"core/operators/sample_data/sample-cat-video.mp4"}
embedding = vid_vec_rep_resnet.run(video)
doc = self.generate_document(file_name, embedding)

media_type = MediaType.VIDEO
result = es.store(media_type, doc)
print("result:", result)

self.assertEqual(result["message"], "multiple media stored")

# @skip
def test_2_search_video_vector(self):
es = ES(self.param)
es.connect()
es.optionally_create_index()

# generate video embedding
vid_vec_rep_resnet.initialize(param=None)
file_name = "sample-cat-video.mp4"
video = {"path": r"core/operators/sample_data/sample-cat-video.mp4"}
embedding = vid_vec_rep_resnet.run(video)
average_vector = next(embedding)

search_result = es.find("video", average_vector.get('vid_vec'))
print("SEARCH RESULTS \n : ")
pp.pprint(search_result)
self.assertEqual(search_result[0].get('dataset'), file_name)
Loading

0 comments on commit 03915d3

Please sign in to comment.