Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ORM for organization model #1914

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from sqlalchemy import engine_from_config, pool

from alembic import context
from letta.base import Base
from letta.config import LettaConfig
from letta.orm.base import Base
from letta.settings import settings

letta_config = LettaConfig.load()
Expand Down
10 changes: 6 additions & 4 deletions letta/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from tqdm import tqdm

from letta.agent_store.storage import StorageConnector, TableType
from letta.base import Base
from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn
from letta.orm.base import Base

# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
from letta.schemas.message import Message
Expand Down Expand Up @@ -509,10 +509,12 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)

self.session_maker = db_context

# import sqlite3
# Need this in order to allow UUIDs to be stored successfully in the sqlite database
mattzh72 marked this conversation as resolved.
Show resolved Hide resolved
import sqlite3
import uuid

# sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
# sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))

def insert_many(self, records, exists_ok=True, show_progress=False):
# TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
Expand Down
3 changes: 0 additions & 3 deletions letta/base.py

This file was deleted.

2 changes: 1 addition & 1 deletion letta/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# Defaults
DEFAULT_USER_ID = "user-00000000"
DEFAULT_ORG_ID = "org-00000000"
DEFAULT_ORG_ID = "organization-f2b8978c-82d8-44b5-a82c-75ef93a10bc7"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this change?

DEFAULT_USER_NAME = "default"
DEFAULT_ORG_NAME = "default"

Expand Down
60 changes: 1 addition & 59 deletions letta/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
)
from sqlalchemy.sql import func

from letta.base import Base
from letta.config import LettaConfig
from letta.orm.base import Base
from letta.schemas.agent import AgentState
from letta.schemas.api_key import APIKey
from letta.schemas.block import Block, Human, Persona
Expand All @@ -34,7 +34,6 @@

# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.organization import Organization
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.user import User
Expand Down Expand Up @@ -174,21 +173,6 @@ def to_record(self) -> User:
return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id)


class OrganizationModel(Base):
__tablename__ = "organizations"
__table_args__ = {"extend_existing": True}

id = Column(String, primary_key=True)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True))

def __repr__(self) -> str:
return f"<Organization(id='{self.id}' name='{self.name}')>"

def to_record(self) -> Organization:
return Organization(id=self.id, name=self.name, created_at=self.created_at)


# TODO: eventually store providers?
# class Provider(Base):
# __tablename__ = "providers"
Expand Down Expand Up @@ -551,14 +535,6 @@ def create_user(self, user: User):
session.add(UserModel(**vars(user)))
session.commit()

@enforce_types
def create_organization(self, organization: Organization):
with self.session_maker() as session:
if session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count() > 0:
raise ValueError(f"Organization with id {organization.id} already exists")
session.add(OrganizationModel(**vars(organization)))
session.commit()

@enforce_types
def create_block(self, block: Block):
with self.session_maker() as session:
Expand Down Expand Up @@ -698,16 +674,6 @@ def delete_user(self, user_id: str):

session.commit()

@enforce_types
def delete_organization(self, org_id: str):
with self.session_maker() as session:
# delete from organizations table
session.query(OrganizationModel).filter(OrganizationModel.id == org_id).delete()

# TODO: delete associated data

session.commit()

@enforce_types
def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50, user_id: Optional[str] = None) -> List[ToolModel]:
with self.session_maker() as session:
Expand Down Expand Up @@ -762,30 +728,6 @@ def get_user(self, user_id: str) -> Optional[User]:
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()

@enforce_types
def get_organization(self, org_id: str) -> Optional[Organization]:
with self.session_maker() as session:
results = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()

@enforce_types
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
with self.session_maker() as session:
query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id))
if cursor:
query = query.filter(OrganizationModel.id < cursor)
results = query.limit(limit).all()
if not results:
return None, []
organization_records = [r.to_record() for r in results]
next_cursor = organization_records[-1].id
assert isinstance(next_cursor, str)

return next_cursor, organization_records

@enforce_types
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
with self.session_maker() as session:
Expand Down
Empty file added letta/orm/__all__.py
Empty file.
Empty file added letta/orm/__init__.py
Empty file.
75 changes: 75 additions & 0 deletions letta/orm/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from datetime import datetime
from typing import Optional
from uuid import UUID

from sqlalchemy import UUID as SQLUUID
from sqlalchemy import Boolean, DateTime, func, text
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
declarative_mixin,
declared_attr,
mapped_column,
)


class Base(DeclarativeBase):
"""absolute base for sqlalchemy classes"""


@declarative_mixin
class CommonSqlalchemyMetaMixins(Base):
__abstract__ = True

created_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now(), server_onupdate=func.now())
is_deleted: Mapped[bool] = mapped_column(Boolean, server_default=text("FALSE"))

@declared_attr
def _created_by_id(cls):
return cls._user_by_id()

@declared_attr
def _last_updated_by_id(cls):
return cls._user_by_id()

@classmethod
def _user_by_id(cls):
"""a flexible non-constrained record of a user.
This way users can get added, deleted etc without history freaking out
"""
return mapped_column(SQLUUID(), nullable=True)

@property
def last_updated_by_id(self) -> Optional[str]:
return self._user_id_getter("last_updated")

@last_updated_by_id.setter
def last_updated_by_id(self, value: str) -> None:
self._user_id_setter("last_updated", value)

@property
def created_by_id(self) -> Optional[str]:
return self._user_id_getter("created")

@created_by_id.setter
def created_by_id(self, value: str) -> None:
self._user_id_setter("created", value)

def _user_id_getter(self, prop: str) -> Optional[str]:
"""returns the user id for the specified property"""
full_prop = f"_{prop}_by_id"
prop_value = getattr(self, full_prop, None)
if not prop_value:
return
return f"user-{prop_value}"

def _user_id_setter(self, prop: str, value: str) -> None:
"""returns the user id for the specified property"""
full_prop = f"_{prop}_by_id"
if not value:
setattr(self, full_prop, None)
return
prefix, id_ = value.split("-", 1)
assert prefix == "user", f"{prefix} is not a valid id prefix for a user id"
setattr(self, full_prop, UUID(id_))
8 changes: 8 additions & 0 deletions letta/orm/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from enum import Enum


class ToolSourceType(str, Enum):
"""Defines what a tool was derived from"""

python = "python"
json = "json"
2 changes: 2 additions & 0 deletions letta/orm/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class NoResultFound(Exception):
"""A record or records cannot be found given the provided search params"""
40 changes: 40 additions & 0 deletions letta/orm/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Optional, Type
from uuid import UUID

from letta.orm.base import Base


class MalformedIdError(Exception):
pass


def _relation_getter(instance: "Base", prop: str) -> Optional[str]:
prefix = prop.replace("_", "")
formatted_prop = f"_{prop}_id"
try:
uuid_ = getattr(instance, formatted_prop)
return f"{prefix}-{uuid_}"
except AttributeError:
return None


def _relation_setter(instance: Type["Base"], prop: str, value: str) -> None:
formatted_prop = f"_{prop}_id"
prefix = prop.replace("_", "")
if not value:
setattr(instance, formatted_prop, None)
return
try:
found_prefix, id_ = value.split("-", 1)
except ValueError as e:
raise MalformedIdError(f"{value} is not a valid ID.") from e
assert (
# TODO: should be able to get this from the Mapped typing, not sure how though
# prefix = getattr(?, "prefix")
found_prefix
== prefix
), f"{found_prefix} is not a valid id prefix, expecting {prefix}"
try:
setattr(instance, formatted_prop, UUID(id_))
except ValueError as e:
raise MalformedIdError("Hash segment of {value} is not a valid UUID") from e
35 changes: 35 additions & 0 deletions letta/orm/organization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import TYPE_CHECKING

from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Mapped, mapped_column

from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.organization import Organization as PydanticOrganization

if TYPE_CHECKING:
from sqlalchemy.orm import Session


class Organization(SqlalchemyBase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this not use the OrganizationMixin? Also, do we need the OrganizationMixin if we only have one ORM model that needs it?

"""The highest level of the object tree. All Entities belong to one and only one Organization."""

__tablename__ = "organizations"
__pydantic_model__ = PydanticOrganization

name: Mapped[str] = mapped_column(doc="The display name of the organization.")

# TODO: Map these relationships later when we actually make these models
# below is just a suggestion
# users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
# agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
# sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
# tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
# documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan")

@classmethod
def default(cls, db_session: "Session") -> "Organization":
"""Get the default org, or create it if it doesn't exist."""
try:
return db_session.query(cls).filter(cls.name == "Default Organization").one()
except NoResultFound:
return cls(name="Default Organization").create(db_session)
Loading
Loading