Skip to content

Commit

Permalink
feat: added dataset recall testing API
Browse files Browse the repository at this point in the history
  • Loading branch information
gubinjie committed Oct 14, 2024
1 parent 42b02b3 commit abdf444
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 2 deletions.
3 changes: 1 addition & 2 deletions api/controllers/service_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# Blueprint for the service API; its routes are mounted under the "/v1" URL prefix.
bp = Blueprint("service_api", __name__, url_prefix="/v1")
# API object wrapping the blueprint. The controller modules imported below
# presumably register their resources on it at import time — TODO confirm
# for modules other than dataset.hit_testing.
api = ExternalApi(bp)


from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, segment
from .dataset import dataset, document, hit_testing, segment
85 changes: 85 additions & 0 deletions api/controllers/service_api/dataset/hit_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import logging

from flask_login import current_user
from flask_restful import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services.dataset_service
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService


class HitTestingApi(DatasetApiResource):
    """Service-API resource that runs a retrieval ("hit") test against a dataset."""

    # Maximum number of records requested from HitTestingService.retrieve.
    RETRIEVE_LIMIT = 10

    def post(self, tenant_id, dataset_id):
        """Run a hit test for ``dataset_id`` and return the matching records.

        The JSON body may carry ``query`` (str), ``retrieval_model`` (dict,
        optional) and ``external_retrieval_model`` (dict, optional). The parsed
        arguments are validated by ``HitTestingService.hit_testing_args_check``;
        any exception it raises propagates unchanged.

        Args:
            tenant_id: Supplied by the resource wrapper; not used in this method.
            dataset_id: Identifier of the dataset to query (converted to ``str``).

        Returns:
            A dict with the echoed ``query`` and the ``records`` marshalled
            through ``hit_testing_record_fields``.

        Raises:
            NotFound: The dataset does not exist.
            Forbidden: ``current_user`` has no permission on the dataset.
            DatasetNotInitializedError: The dataset index is not initialized.
            ProviderNotInitializeError: Provider token missing, or no usable
                embedding / reranking model.
            ProviderQuotaExceededError: Provider quota exhausted.
            ProviderModelCurrentlyNotSupportError: Model currently unsupported.
            CompletionRequestError: Model invocation failed.
            ValueError: Re-raised unchanged from the retrieval call.
            InternalServerError: Any other unexpected failure.
        """
        dataset_id_str = str(dataset_id)
        # Routine request tracing belongs at DEBUG level; logging it as an
        # error on every call would pollute error logs and alerting.
        logging.debug("Hit testing requested for dataset %s", dataset_id_str)

        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e

        parser = reqparse.RequestParser()
        parser.add_argument("query", type=str, location="json")
        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
        args = parser.parse_args()

        HitTestingService.hit_testing_args_check(args)

        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=args["query"],
                account=current_user,
                retrieval_model=args["retrieval_model"],
                external_retrieval_model=args["external_retrieval_model"],
                limit=self.RETRIEVE_LIMIT,
            )

            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
        except services.errors.index.IndexNotInitializedError:
            raise DatasetNotInitializedError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model or Reranking Model available. Please configure a valid provider "
                "in the Settings -> Model Provider."
            )
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError:
            # Keep ValueError out of the generic handler below, and re-raise
            # the original instance so its traceback and args are preserved.
            raise
        except Exception as e:
            logging.exception("Hit testing failed.")
            raise InternalServerError(str(e)) from e


# Register the hit-testing resource on the service API object.
# NOTE(review): the "/datasets2" prefix looks like a leftover or typo for
# "/datasets" — confirm against the other dataset routes before release.
api.add_resource(HitTestingApi, "/datasets2/<uuid:dataset_id>/hit-testing")

0 comments on commit abdf444

Please sign in to comment.