diff --git a/literalai/instrumentation/openai.py b/literalai/instrumentation/openai.py index df1894e..b388eb4 100644 --- a/literalai/instrumentation/openai.py +++ b/literalai/instrumentation/openai.py @@ -130,6 +130,7 @@ def init_generation(generation_type: "GenerationType", kwargs): tools=tools, settings=settings, messages=messages, + tags=kwargs.get("literal_tags"), ) elif generation_type == GenerationType.COMPLETION: @@ -156,6 +157,7 @@ def init_generation(generation_type: "GenerationType", kwargs): model=model, settings=settings, prompt=kwargs.get("prompt"), + tags=kwargs.get("literal_tags"), ) def update_step_after( diff --git a/literalai/wrappers.py b/literalai/wrappers.py index 5dff0da..4cce786 100644 --- a/literalai/wrappers.py +++ b/literalai/wrappers.py @@ -23,6 +23,21 @@ class AfterContext(TypedDict): start: float +def remove_literal_args(kargs): + '''Remove argument prefixed with "literal_" from kwargs and return them in a separate dict''' + largs = {} + for key in list(kargs.keys()): + if key.startswith("literal_"): + value = kargs.pop(key) + largs[key] = value + return largs + +def restore_literal_args(kargs, largs): + '''Reverse the effect of remove_literal_args by merging the literal arguments into kwargs''' + for key in list(largs.keys()): + kargs[key] = largs[key] + + def sync_wrapper(before_func=None, after_func=None): def decorator(original_func): @wraps(original_func) @@ -31,6 +46,8 @@ def wrapped(*args, **kwargs): # If a before_func is provided, call it with the shared context. if before_func: before_func(context, *args, **kwargs) + # Remove literal arguments before calling the original function + literal_args = remove_literal_args(kwargs) context["start"] = time.time() try: result = original_func(*args, **kwargs) @@ -44,6 +61,7 @@ def wrapped(*args, **kwargs): raise e # If an after_func is provided, call it with the result and the shared context. if after_func: + restore_literal_args(kwargs, literal_args) result = after_func(result, context, *args, **kwargs) return result @@ -62,6 +80,8 @@ async def wrapped(*args, **kwargs): if before_func: await before_func(context, *args, **kwargs) + # Remove literal arguments before calling the original function + literal_args = remove_literal_args(kwargs) context["start"] = time.time() try: result = await original_func(*args, **kwargs) @@ -76,6 +96,7 @@ async def wrapped(*args, **kwargs): # If an after_func is provided, call it with the result and the shared context. if after_func: + restore_literal_args(kwargs, literal_args) result = await after_func(result, context, *args, **kwargs) return result