Skip to content

Commit

Permalink
Adds annotations create workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Oct 8, 2024
1 parent 30f2827 commit 269132e
Show file tree
Hide file tree
Showing 24 changed files with 12,607 additions and 5,977 deletions.
112 changes: 66 additions & 46 deletions burr/tracking/server/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import os.path
import sys
from datetime import datetime
from typing import Any, Optional, Sequence, Tuple, Type, TypeVar

import aiofiles
Expand Down Expand Up @@ -91,18 +92,12 @@ async def update_annotation(
self,
annotation: AnnotationUpdate,
project_id: str,
partition_key: Optional[str],
app_id: str,
step_sequence_id: int,
annotation_id: int,
) -> AnnotationOut:
"""Updates an annotation -- annotation has annotation data, the other pointers are given in the parameters.
:param annotation: Annotation object to update
:param project_id: Project ID to associate with
:param partition_key: Partition key to associate with
:param app_id: App ID to associate with
:param step_sequence_id: Step sequence ID to associate with
:param annotation_id: Annotation ID to update. We include this as we may have multiple...
:return: Updated annotation
"""
Expand Down Expand Up @@ -262,6 +257,21 @@ class LocalBackend(BackendBase, AnnotationsBackendMixin):
To override the path, set a `burr_path` environment variable to the path you want to use.
"""

def __init__(self, path: str = DEFAULT_PATH):
self.path = path

def _get_annotation_path(self, project_id: str) -> str:
return os.path.join(self.path, project_id, "annotations.jsonl")

async def _load_project_annotations(self, project_id: str):
annotations_path = self._get_annotation_path(project_id)
annotations = []
if os.path.exists(annotations_path):
async with aiofiles.open(annotations_path) as f:
for line in await f.readlines():
annotations.append(AnnotationOut.parse_raw(line))
return annotations

async def create_annotation(
self,
annotation: AnnotationCreate,
Expand All @@ -270,18 +280,50 @@ async def create_annotation(
app_id: str,
step_sequence_id: int,
) -> AnnotationOut:
...
all_annotations = await self._load_project_annotations(project_id)
annotation_id = (
max([a.id for a in all_annotations], default=-1) + 1
) # get the ID, increment
annotation_out = AnnotationOut(
id=annotation_id,
project_id=project_id,
app_id=app_id,
partition_key=partition_key,
step_sequence_id=step_sequence_id,
created=datetime.now(),
updated=datetime.now(),
**annotation.dict(),
)
annotations_path = self._get_annotation_path(project_id)
async with aiofiles.open(annotations_path, "a") as f:
await f.write(annotation_out.json() + "\n")
return annotation_out

async def update_annotation(
self,
annotation: AnnotationUpdate,
project_id: str,
partition_key: Optional[str],
app_id: str,
step_sequence_id: int,
annotation_id: int,
) -> AnnotationOut:
...
all_annotations = await self._load_project_annotations(project_id)
annotation_out = None
for idx, a in enumerate(all_annotations):
if a.id == annotation_id:
annotation_out = a
all_annotations[idx] = annotation_out.copy(
update={**annotation.dict(), "updated": datetime.now()}
)
break
if annotation_out is None:
raise fastapi.HTTPException(
status_code=404,
detail=f"Annotation: {annotation_id} from project: {project_id} not found",
)
annotations_path = self._get_annotation_path(project_id)
async with aiofiles.open(annotations_path, "w") as f:
for a in all_annotations:
await f.write(a.json() + "\n")
return annotation_out

async def get_annotations(
self,
Expand All @@ -290,43 +332,21 @@ async def get_annotations(
app_id: Optional[str] = None,
step_sequence_id: Optional[int] = None,
) -> Sequence[AnnotationOut]:
def generate_fake_annotations(
project_id: str, app_id: Optional[str] = None, partition_key: Optional[str] = None
) -> Sequence[AnnotationOut]:
import random

import lorem

random.seed(42) # Setting a fixed seed for reproducibility
common_tags = ["urgent", "review", "important"]
annotations = []
for i in range(1, 6):
annotation = AnnotationOut(
id=i,
project_id=project_id,
app_id=app_id,
partition_key=partition_key,
step_sequence_id=i,
attributes=[],
span_id=None,
tags=[random.choice(common_tags) for _ in range(random.randint(1, 3))],
note=lorem.paragraph(),
thumbs_up_thumbs_down=random.choice([True, False, None]),
)
annotations.append(annotation)

return annotations

# Example usage
annotations = generate_fake_annotations(project_id="proj_123", app_id="app_1")

# Example usage
annotations = generate_fake_annotations(project_id="proj_123", app_id="app_1")
annotation_path = self._get_annotation_path(project_id)
if not os.path.exists(annotation_path):
return []
annotations = []
async with aiofiles.open(annotation_path) as f:
for line in await f.readlines():
parsed = AnnotationOut.parse_raw(line)
if (
(partition_key is None or parsed.partition_key == partition_key)
and (app_id is None or parsed.app_id == app_id)
and (step_sequence_id is None or parsed.step_sequence_id == step_sequence_id)
):
annotations.append(parsed)
return annotations

def __init__(self, path: str = DEFAULT_PATH):
self.path = path

async def list_projects(self, request: fastapi.Request) -> Sequence[schema.Project]:
out = []
if not os.path.exists(self.path):
Expand Down
32 changes: 20 additions & 12 deletions burr/tracking/server/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from starlette.templating import Jinja2Templates

from burr.tracking.server import schema
from burr.tracking.server.schema import (
from burr.tracking.server.schema import ( # AnnotationUpdate,
AnnotationCreate,
AnnotationOut,
AnnotationUpdate,
Expand Down Expand Up @@ -227,31 +227,39 @@ async def get_application_logs(
)


@app.post("/api/v0/{project_id}/{app_id}/{partition_key}/annotations", response_model=AnnotationOut)
@app.post(
"/api/v0/{project_id}/{app_id}/{partition_key}/{sequence_id}annotations",
response_model=AnnotationOut,
)
async def create_annotation(
request: Request, project_id: str, app_id: str, partition_key: str, annotation: AnnotationCreate
request: Request,
project_id: str,
app_id: str,
partition_key: str,
sequence_id: int,
annotation: AnnotationCreate,
):
if partition_key == SENTINEL_PARTITION_KEY:
partition_key = None
return await backend.create_annotation(request, project_id, app_id, partition_key, annotation)
return await backend.create_annotation(
annotation, project_id, partition_key, app_id, sequence_id
)


#
# # TODO -- take out these parameters cause we have the annotation ID
@app.put(
"/api/v0/{project_id}/{app_id}/{partition_key}/annotations/{annotation_id}",
"/api/v0/{project_id}/{annotation_id}/update_annotations",
response_model=AnnotationOut,
)
async def update_annotation(
request: Request,
project_id: str,
app_id: str,
partition_key: str,
annotation_id: int,
annotation: AnnotationUpdate,
):
if partition_key == SENTINEL_PARTITION_KEY:
partition_key = None
return await backend.update_annotation(
request, project_id, app_id, partition_key, annotation_id, annotation
annotation_id=annotation_id, annotation=annotation, project_id=project_id
)


Expand All @@ -261,14 +269,14 @@ async def get_annotations(
project_id: str,
app_id: Optional[str] = None,
partition_key: Optional[str] = None,
annotation_id: Optional[int] = None,
step_sequence_id: Optional[int] = None,
):
# Handle the sentinel value for partition_key
if partition_key == SENTINEL_PARTITION_KEY:
partition_key = None

# Logic to retrieve the annotations
return await backend.get_annotations(project_id, app_id, partition_key, annotation_id)
return await backend.get_annotations(project_id, partition_key, app_id, step_sequence_id)


@app.get("/api/v0/ready")
Expand Down
30 changes: 23 additions & 7 deletions burr/tracking/server/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import collections
import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

import pydantic
from pydantic import fields
Expand Down Expand Up @@ -185,25 +185,39 @@ class BackendSpec(pydantic.BaseModel):
supports_annotations: bool


class AnnotationDataPointer(pydantic.BaseModel):
type: Literal["state_field", "attribute"]
field_name: str # key of attribute/state field
span_id: Optional[
str
] # span_id if it's associated with a span, otherwise it's associated with an action


AllowedDataField = Literal["note", "ground_truth"]


class AnnotationObservation(pydantic.BaseModel):
data_fields: dict[str, Any]
thumbs_up_thumbs_down: Optional[bool]
data_pointers: List[AnnotationDataPointer]


class AnnotationCreate(pydantic.BaseModel):
"""Generic link for indexing job -- can be exposed in 'admin mode' in the UI"""

attributes: List[str] # list of associated attributes that are part of this
span_id: Optional[str]
step_name: str # Should be able to look it up but including for now
tags: List[str]
note: str
thumbs_up_thumbs_down: Optional[bool]
observations: List[AnnotationObservation]


class AnnotationUpdate(AnnotationCreate):
"""Generic link for indexing job -- can be exposed in 'admin mode' in the UI"""

# Identification for association
attributes: Optional[List[str]] = None # list of associated attributes that are part of this
span_id: Optional[str] = None
tags: Optional[List[str]] = []
note: Optional[str] = None
thumbs_up_thumbs_down: Optional[bool] = None
observations: List[AnnotationObservation]


class AnnotationOut(AnnotationCreate):
Expand All @@ -215,3 +229,5 @@ class AnnotationOut(AnnotationCreate):
app_id: str
partition_key: Optional[str]
step_sequence_id: int
created: datetime.datetime
updated: datetime.datetime
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "burr"
version = "0.30.4"
version = "0.31.0rc1"
dependencies = [] # yes, there are none
requires-python = ">=3.9"
authors = [
Expand Down
Loading

0 comments on commit 269132e

Please sign in to comment.