Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add support for initializing vecs client with custom schema #63

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ repos:
hooks:
- id: autoflake
args: ['--in-place', '--remove-all-unused-imports']
language_version: python3.8

- repo: https://github.com/ambv/black
rev: 22.10.0
hooks:
- id: black
language_version: python3.9
language_version: python3.8
Comment on lines 28 to +29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what was the rationale for this change?

2 changes: 2 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import vecs

PYTEST_DB = "postgresql://postgres:password@localhost:5611/vecs_db"
PYTEST_SCHEMA = "test_schema"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the best way to test that escaping is done correctly in all places is to use a crazy schema name.

I tested basic operations with the schema name "esCape Me!" and the test suite fails.

To reproduce that, try:

foo: Collection = client.get_or_create_collection(name="foo", schema="esCape Me!", dimension=5)

and you'll get

sqlalchemy.exc.ProgrammingError: (psycopg2.errors.InvalidSchemaName) schema ""esCape Me!"" does not exist



@pytest.fixture(scope="session")
Expand Down Expand Up @@ -95,6 +96,7 @@ def clean_db(maybe_start_pg: None) -> Generator[str, None, None]:
eng = create_engine(PYTEST_DB)
with eng.begin() as connection:
connection.execute(text("drop schema if exists vecs cascade;"))
connection.execute(text(f"drop schema if exists {PYTEST_SCHEMA} cascade;"))
yield PYTEST_DB
eng.dispose()

Expand Down
6 changes: 6 additions & 0 deletions src/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@ def test_get_collection(client: vecs.Client) -> None:


def test_list_collections(client: vecs.Client) -> None:
"""
Test list_collections returns appropriate results for default schema (vecs) and custom schema
"""
assert len(client.list_collections()) == 0
client.get_or_create_collection(name="docs", dimension=384)
client.get_or_create_collection(name="books", dimension=1586)
client.get_or_create_collection(name="movies", schema="test_schema", dimension=384)
collections = client.list_collections()
collections_test_schema = client.list_collections(schema="test_schema")
assert len(collections) == 2
assert len(collections_test_schema) == 1


def test_delete_collection(client: vecs.Client) -> None:
Expand Down
66 changes: 66 additions & 0 deletions src/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,3 +815,69 @@ def test_hnsw_unavailable_error(client: vecs.Client) -> None:
bar = client.get_or_create_collection(name="bar", dimension=dim)
with pytest.raises(ArgError):
bar.create_index(method=IndexMethod.hnsw)


def test_get_or_create_with_schema(client: vecs.Client):
"""
Test that get_or_create_collection works when specifying custom schema
"""

dim = 384

collection_1 = client.get_or_create_collection(
name="collection_1", schema="test_schema", dimension=dim
)
collection_2 = client.get_or_create_collection(
name="collection_1", schema="test_schema", dimension=dim
)

assert collection_1.schema == "test_schema"
assert collection_1.schema == collection_2.schema
assert collection_1.name == collection_2.name


def test_upsert_with_schema(client: vecs.Client) -> None:
n_records = 100
dim = 384

movies1 = client.get_or_create_collection(
name="ping", schema="test_schema", dimension=dim
)
movies2 = client.get_or_create_collection(name="ping", schema="vecs", dimension=dim)

# collection initially empty
assert len(movies1) == 0
assert len(movies2) == 0

records = [
(
f"vec{ix}",
vec,
{
"genre": random.choice(["action", "rom-com", "drama"]),
"year": int(50 * random.random()) + 1970,
},
)
for ix, vec in enumerate(np.random.random((n_records, dim)))
]

# insert works
movies1.upsert(records)
assert len(movies1) == n_records

movies2.upsert(records)
assert len(movies2) == n_records

# upserting overwrites
new_record = ("vec0", np.zeros(384), {})
movies1.upsert([new_record])
db_record = movies1["vec0"]
db_record[0] == new_record[0]
db_record[1] == new_record[1]
db_record[2] == new_record[2]

movies2.upsert([new_record])
db_record = movies2["vec0"]
db_record[0] == new_record[0]
db_record[1] == new_record[1]
db_record[2] == new_record[2]
7 changes: 5 additions & 2 deletions src/vecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,8 @@


def create_client(connection_string: str) -> Client:
"""Creates a client from a Postgres connection string"""
return Client(connection_string)
"""
Creates a client from a Postgres connection string and optional schema.
Defaults to `vecs` schema.
Comment on lines +28 to +29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

 and optional schema. Defaults to `vecs` schema.

can be removed

"""
return Client(connection_string=connection_string)
17 changes: 9 additions & 8 deletions src/vecs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, List, Optional

from deprecated import deprecated
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from vecs.adapter import Adapter
Expand Down Expand Up @@ -53,12 +53,10 @@ def __init__(self, connection_string: str):

Args:
connection_string (str): A string representing the database connection information.

Returns:
None
"""
self.engine = create_engine(connection_string)
self.meta = MetaData(schema="vecs")
self.Session = sessionmaker(self.engine)

with self.Session() as sess:
Expand All @@ -84,6 +82,7 @@ def get_or_create_collection(
self,
name: str,
*,
schema: str = "vecs",
dimension: Optional[int] = None,
adapter: Optional[Adapter] = None,
) -> Collection:
Expand Down Expand Up @@ -113,6 +112,7 @@ def get_or_create_collection(
dimension=dimension or adapter_dimension, # type: ignore
client=self,
adapter=adapter,
schema=schema,
)

return collection._create_if_not_exists()
Expand Down Expand Up @@ -182,32 +182,33 @@ def get_collection(self, name: str) -> Collection:
self,
)

def list_collections(self) -> List["Collection"]:
def list_collections(self, *, schema: str = "vecs") -> List["Collection"]:
"""
List all vector collections.
List all vector collections by database schema.

Returns:
list[Collection]: A list of all collections.
"""
from vecs.collection import Collection

return Collection._list_collections(self)
return Collection._list_collections(self, schema)

def delete_collection(self, name: str) -> None:
def delete_collection(self, name: str, *, schema: str = "vecs") -> None:
"""
Delete a vector collection.

If no collection with requested name exists, does nothing.

Args:
name (str): The name of the collection.
schema (str): Optional, the database schema. Defaults to `vecs`.

Returns:
None
"""
from vecs.collection import Collection

Collection(name, -1, self)._drop()
Collection(name, -1, self, schema=schema)._drop()
return

def disconnect(self) -> None:
Expand Down
38 changes: 26 additions & 12 deletions src/vecs/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(
dimension: int,
client: Client,
adapter: Optional[Adapter] = None,
schema: Optional[str] = "vecs",
):
"""
Initializes a new instance of the `Collection` class.
Expand All @@ -174,7 +175,12 @@ def __init__(
self.client = client
self.name = name
self.dimension = dimension
self.table = build_table(name, client.meta, dimension)
self._schema = schema
self.schema = self.client.engine.dialect.identifier_preparer.quote_schema(
self._schema
)
self.meta = MetaData(schema=self.schema)
self.table = build_table(name, self.meta, dimension)
self._index: Optional[str] = None
self.adapter = adapter or Adapter(steps=[NoOp(dimension=dimension)])

Expand All @@ -195,6 +201,10 @@ def __init__(
"Dimensions reported by adapter, dimension, and collection do not match"
)

with self.client.Session() as sess:
with sess.begin():
sess.execute(text(f"create schema if not exists {self.schema};"))
olirice marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self):
"""
Returns a string representation of the `Collection` instance.
Expand Down Expand Up @@ -235,7 +245,7 @@ def _create_if_not_exists(self):
join pg_attribute pa
on pc.oid = pa.attrelid
where
pc.relnamespace = 'vecs'::regnamespace
pc.relnamespace = '{self.schema}'::regnamespace
and pc.relkind = 'r'
and pa.attname = 'vec'
and not pc.relname ^@ '_'
Expand Down Expand Up @@ -285,11 +295,12 @@ def _create(self):

unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
with self.client.Session() as sess:
sess.execute(text(f"create schema if not exists {self.schema};"))
sess.execute(
text(
f"""
create index ix_meta_{unique_string}
on vecs."{self.table.name}"
on {self.schema}."{self.table.name}"
using gin ( metadata jsonb_path_ops )
"""
)
Expand Down Expand Up @@ -562,21 +573,22 @@ def query(
return sess.execute(stmt).fetchall() or []

@classmethod
def _list_collections(cls, client: "Client") -> List["Collection"]:
def _list_collections(cls, client: "Client", schema: str) -> List["Collection"]:
"""
PRIVATE

Retrieves all collections from the database.

Args:
client (Client): The database client.
schema (str): The database schema to query.

Returns:
List[Collection]: A list of all existing collections.
List[Collection]: A list of all existing collections within the specified schema.
"""

query = text(
"""
f"""
select
relname as table_name,
atttypmod as embedding_dim
Expand All @@ -585,7 +597,7 @@ def _list_collections(cls, client: "Client") -> List["Collection"]:
join pg_attribute pa
on pc.oid = pa.attrelid
where
pc.relnamespace = 'vecs'::regnamespace
pc.relnamespace = '{schema}'::regnamespace
and pc.relkind = 'r'
and pa.attname = 'vec'
and not pc.relname ^@ '_'
Expand Down Expand Up @@ -636,13 +648,13 @@ def index(self) -> Optional[str]:

if self._index is None:
query = text(
"""
f"""
select
relname as table_name
from
pg_class pc
where
pc.relnamespace = 'vecs'::regnamespace
pc.relnamespace = '{self.schema}'::regnamespace
and relname ilike 'ix_vector%'
and pc.relkind = 'i'
"""
Expand Down Expand Up @@ -760,7 +772,9 @@ def create_index(
with sess.begin():
if self.index is not None:
if replace:
sess.execute(text(f'drop index vecs."{self.index}";'))
sess.execute(
text(f'drop index "{self.schema}"."{self.index}";')
)
self._index = None
else:
raise ArgError("replace is set to False but an index exists")
Expand All @@ -787,7 +801,7 @@ def create_index(
text(
f"""
create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string}
on vecs."{self.table.name}"
on {self.schema}."{self.table.name}"
using ivfflat (vec {ops}) with (lists={n_lists})
"""
)
Expand All @@ -806,7 +820,7 @@ def create_index(
text(
f"""
create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string}
on vecs."{self.table.name}"
on {self.schema}."{self.table.name}"
using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction});
"""
)
Expand Down
Loading